@@ -24,8 +24,12 @@ struct ManagerContextMutex<T> where T: 'static {
2424 #[ borrows( array) ]
2525 #[ covariant]
2626 lock : MutexGuard < ' this , Vec < T > > ,
27- shape : i64 ,
28- stride : i64 ,
27+ // Use Box<i64> so that pointers derived via with_*_mut target heap
28+ // memory rather than inline struct fields. This avoids Stacked Borrows
29+ // violations when multiple with_*_mut calls each create exclusive
30+ // reborrows of the ouroboros struct.
31+ shape : Box < i64 > ,
32+ stride : Box < i64 > ,
2933}
3034
3135unsafe extern "C" fn mutex_deleter_fn < T > ( manager : * mut sys:: DLManagedTensorVersioned ) where T : ' static {
@@ -41,24 +45,24 @@ impl<T> TryFrom<Arc<Mutex<Vec<T>>>> for DLPackTensor where T: GetDLPackDataType
4145 let ctx = ManagerContextMutexBuilder {
4246 array : array,
4347 lock_builder : |array| { array. lock ( ) . expect ( "could not lock the mutex" ) } ,
44- shape : 0 ,
45- stride : 1 ,
48+ shape : Box :: new ( 0 ) ,
49+ stride : Box :: new ( 1 ) ,
4650 } ;
4751 let mut ctx = Box :: new ( ctx. build ( ) ) ;
4852
4953 // set the shape after acquiring the lock to avoid deadlocks
5054 let shape = ctx. with_lock ( |lock| lock. len ( ) as i64 ) ;
51- ctx. with_shape_mut ( |v| * v = shape) ;
55+ ctx. with_shape_mut ( |v| * * v = shape) ;
5256
5357 // extract pointers out of the boxed context to use in the DLPack tensor
5458 let mut shape_ptr = std:: ptr:: null_mut ( ) ;
5559 ctx. with_shape_mut ( |shape| {
56- shape_ptr = shape as * mut i64 ;
60+ shape_ptr = shape. as_mut ( ) ;
5761 } ) ;
5862
5963 let mut stride_ptr = std:: ptr:: null_mut ( ) ;
6064 ctx. with_stride_mut ( |stride| {
61- stride_ptr = stride as * mut i64 ;
65+ stride_ptr = stride. as_mut ( ) ;
6266 } ) ;
6367
6468 let mut data = std:: ptr:: null_mut ( ) ;
@@ -102,8 +106,8 @@ struct ManagerContextRwLock<T> where T: 'static {
102106 #[ borrows( array) ]
103107 #[ covariant]
104108 lock : RwLockWriteGuard < ' this , Vec < T > > ,
105- shape : i64 ,
106- stride : i64 ,
109+ shape : Box < i64 > ,
110+ stride : Box < i64 > ,
107111}
108112
109113unsafe extern "C" fn rwlock_deleter_fn < T > ( manager : * mut sys:: DLManagedTensorVersioned ) where T : ' static {
@@ -119,24 +123,24 @@ impl<T> TryFrom<Arc<RwLock<Vec<T>>>> for DLPackTensor where T: GetDLPackDataType
119123 let ctx = ManagerContextRwLockBuilder {
120124 array : array,
121125 lock_builder : move |array| { array. write ( ) . expect ( "could not lock the rwlock" ) } ,
122- shape : 0 ,
123- stride : 1 ,
126+ shape : Box :: new ( 0 ) ,
127+ stride : Box :: new ( 1 ) ,
124128 } ;
125129 let mut ctx = Box :: new ( ctx. build ( ) ) ;
126130
127131 // set the shape after acquiring the lock to avoid deadlocks
128132 let shape = ctx. with_lock ( |lock| lock. len ( ) as i64 ) ;
129- ctx. with_shape_mut ( |v| * v = shape) ;
133+ ctx. with_shape_mut ( |v| * * v = shape) ;
130134
131135 // extract pointers out of the boxed context to use in the DLPack tensor
132136 let mut shape_ptr = std:: ptr:: null_mut ( ) ;
133137 ctx. with_shape_mut ( |shape| {
134- shape_ptr = shape as * mut i64 ;
138+ shape_ptr = shape. as_mut ( ) ;
135139 } ) ;
136140
137141 let mut stride_ptr = std:: ptr:: null_mut ( ) ;
138142 ctx. with_stride_mut ( |stride| {
139- stride_ptr = stride as * mut i64 ;
143+ stride_ptr = stride. as_mut ( ) ;
140144 } ) ;
141145
142146 let mut data = std:: ptr:: null_mut ( ) ;
@@ -215,4 +219,27 @@ mod tests {
215219 let lock = data. read ( ) . unwrap ( ) ;
216220 assert_eq ! ( & * lock, & [ 1 , 42 , 3 ] ) ;
217221 }
222+
223+ // Last-ref tests: the tensor holds the only Arc reference, so dropping
224+ // it actually deallocates the ManagerContext via the deleter function.
225+
226+ #[ test]
227+ fn test_mutex_last_arc_ref ( ) {
228+ let data = Arc :: new ( Mutex :: new ( vec ! [ 1i32 , 2 , 3 ] ) ) ;
229+
230+ let mut tensor: DLPackTensor = data. try_into ( ) . unwrap ( ) ;
231+ let tensor_mut_ref = tensor. as_mut ( ) ;
232+ let slice: & mut [ i32 ] = tensor_mut_ref. try_into ( ) . unwrap ( ) ;
233+ assert_eq ! ( slice, & [ 1 , 2 , 3 ] ) ;
234+ }
235+
236+ #[ test]
237+ fn test_rwlock_last_arc_ref ( ) {
238+ let data = Arc :: new ( RwLock :: new ( vec ! [ 1i32 , 2 , 3 ] ) ) ;
239+
240+ let mut tensor: DLPackTensor = data. try_into ( ) . unwrap ( ) ;
241+ let tensor_mut_ref = tensor. as_mut ( ) ;
242+ let slice: & mut [ i32 ] = tensor_mut_ref. try_into ( ) . unwrap ( ) ;
243+ assert_eq ! ( slice, & [ 1 , 2 , 3 ] ) ;
244+ }
218245}
0 commit comments