@@ -184,9 +184,9 @@ use crate::npyffi::{self, PyArrayObject, NPY_ARRAY_WRITEABLE};
184184#[ derive( Clone , Copy , PartialEq , Eq , Hash ) ]
185185struct BorrowKey {
186186 /// exclusive range of lowest and highest address covered by array
187- range : ( usize , usize ) ,
187+ range : ( * mut u8 , * mut u8 ) ,
188188 /// the data address on which address computations are based
189- data_ptr : usize ,
189+ data_ptr : * mut u8 ,
190190 /// the greatest common divisor of the strides of the array
191191 gcd_strides : isize ,
192192}
@@ -199,7 +199,7 @@ impl BorrowKey {
199199 {
200200 let range = data_range ( array) ;
201201
202- let data_ptr = array. data ( ) as usize ;
202+ let data_ptr = array. data ( ) as * mut u8 ;
203203 let gcd_strides = gcd_strides ( array. strides ( ) ) ;
204204
205205 Self {
@@ -225,7 +225,7 @@ impl BorrowKey {
225225 // but fails when slicing an array with a step size that does not divide the dimension along that axis.
226226 //
227227 // https://users.rust-lang.org/t/math-for-borrow-checking-numpy-arrays/73303
228- let ptr_diff = abs_diff ( self . data_ptr , other. data_ptr ) as isize ;
228+ let ptr_diff = unsafe { self . data_ptr . offset_from ( other. data_ptr ) . abs ( ) } ;
229229 let gcd_strides = gcd ( self . gcd_strides , other. gcd_strides ) ;
230230
231231 if ptr_diff % gcd_strides != 0 {
@@ -237,7 +237,7 @@ impl BorrowKey {
237237 }
238238}
239239
240- type BorrowFlagsInner = AHashMap < usize , AHashMap < BorrowKey , isize > > ;
240+ type BorrowFlagsInner = AHashMap < * mut u8 , AHashMap < BorrowKey , isize > > ;
241241
242242struct BorrowFlags ( UnsafeCell < Option < BorrowFlagsInner > > ) ;
243243
@@ -253,7 +253,7 @@ impl BorrowFlags {
253253 ( * self . 0 . get ( ) ) . get_or_insert_with ( AHashMap :: new)
254254 }
255255
256- fn acquire ( & self , _py : Python , address : usize , key : BorrowKey ) -> Result < ( ) , BorrowError > {
256+ fn acquire ( & self , _py : Python , address : * mut u8 , key : BorrowKey ) -> Result < ( ) , BorrowError > {
257257 // SAFETY: Having `_py` implies holding the GIL and
258258 // we are not calling into user code which might re-enter this function.
259259 let borrow_flags = unsafe { BORROW_FLAGS . get ( ) } ;
@@ -296,7 +296,7 @@ impl BorrowFlags {
296296 Ok ( ( ) )
297297 }
298298
299- fn release ( & self , _py : Python , address : usize , key : BorrowKey ) {
299+ fn release ( & self , _py : Python , address : * mut u8 , key : BorrowKey ) {
300300 // SAFETY: Having `_py` implies holding the GIL and
301301 // we are not calling into user code which might re-enter this function.
302302 let borrow_flags = unsafe { BORROW_FLAGS . get ( ) } ;
@@ -316,7 +316,12 @@ impl BorrowFlags {
316316 }
317317 }
318318
319- fn acquire_mut ( & self , _py : Python , address : usize , key : BorrowKey ) -> Result < ( ) , BorrowError > {
319+ fn acquire_mut (
320+ & self ,
321+ _py : Python ,
322+ address : * mut u8 ,
323+ key : BorrowKey ,
324+ ) -> Result < ( ) , BorrowError > {
320325 // SAFETY: Having `_py` implies holding the GIL and
321326 // we are not calling into user code which might re-enter this function.
322327 let borrow_flags = unsafe { BORROW_FLAGS . get ( ) } ;
@@ -353,7 +358,7 @@ impl BorrowFlags {
353358 Ok ( ( ) )
354359 }
355360
356- fn release_mut ( & self , _py : Python , address : usize , key : BorrowKey ) {
361+ fn release_mut ( & self , _py : Python , address : * mut u8 , key : BorrowKey ) {
357362 // SAFETY: Having `_py` implies holding the GIL and
358363 // we are not calling into user code which might re-enter this function.
359364 let borrow_flags = unsafe { BORROW_FLAGS . get ( ) } ;
@@ -383,7 +388,7 @@ where
383388 D : Dimension ,
384389{
385390 array : & ' py PyArray < T , D > ,
386- address : usize ,
391+ address : * mut u8 ,
387392 key : BorrowKey ,
388393}
389394
@@ -526,7 +531,7 @@ where
526531 D : Dimension ,
527532{
528533 array : & ' py PyArray < T , D > ,
529- address : usize ,
534+ address : * mut u8 ,
530535 key : BorrowKey ,
531536}
532537
@@ -680,30 +685,35 @@ where
680685 }
681686}
682687
683- fn base_address < T , D > ( array : & PyArray < T , D > ) -> usize {
684- fn inner ( py : Python , mut array : * mut PyArrayObject ) -> usize {
688+ fn base_address < T , D > ( array : & PyArray < T , D > ) -> * mut u8 {
689+ fn inner ( py : Python , mut array : * mut PyArrayObject ) -> * mut u8 {
685690 loop {
686691 let base = unsafe { ( * array) . base } ;
687692
688693 if base. is_null ( ) {
689- return array as usize ;
694+ return array as * mut u8 ;
690695 } else if unsafe { npyffi:: PyArray_Check ( py, base) } != 0 {
691696 array = base as * mut PyArrayObject ;
692697 } else {
693- return base as usize ;
698+ return base as * mut u8 ;
694699 }
695700 }
696701 }
697702
698703 inner ( array. py ( ) , array. as_array_ptr ( ) )
699704}
700705
701- fn data_range < T , D > ( array : & PyArray < T , D > ) -> ( usize , usize )
706+ fn data_range < T , D > ( array : & PyArray < T , D > ) -> ( * mut u8 , * mut u8 )
702707where
703708 T : Element ,
704709 D : Dimension ,
705710{
706- fn inner ( shape : & [ usize ] , strides : & [ isize ] , itemsize : isize , data : * mut u8 ) -> ( usize , usize ) {
711+ fn inner (
712+ shape : & [ usize ] ,
713+ strides : & [ isize ] ,
714+ itemsize : isize ,
715+ data : * mut u8 ,
716+ ) -> ( * mut u8 , * mut u8 ) {
707717 let mut start = 0 ;
708718 let mut end = 0 ;
709719
@@ -721,33 +731,24 @@ where
721731 end += itemsize;
722732 }
723733
724- let start = unsafe { data. offset ( start) } as usize ;
725- let end = unsafe { data. offset ( end) } as usize ;
734+ let start = unsafe { data. offset ( start) } ;
735+ let end = unsafe { data. offset ( end) } ;
726736
727737 ( start, end)
728738 }
729739
730740 inner (
731741 array. shape ( ) ,
732742 array. strides ( ) ,
733- size_of :: < T > ( ) as _ ,
734- array. data ( ) as _ ,
743+ size_of :: < T > ( ) as isize ,
744+ array. data ( ) as * mut u8 ,
735745 )
736746}
737747
738748fn gcd_strides ( strides : & [ isize ] ) -> isize {
739749 reduce ( strides. iter ( ) . copied ( ) , gcd) . unwrap_or ( 1 )
740750}
741751
742- // FIXME(adamreichold): Use `usize::abs_diff` from std when our MSRV reaches 1.60.
743- fn abs_diff ( lhs : usize , rhs : usize ) -> usize {
744- if lhs >= rhs {
745- lhs - rhs
746- } else {
747- rhs - lhs
748- }
749- }
750-
751752// FIXME(adamreichold): Use `Iterator::reduce` from std when our MSRV reaches 1.51.
752753fn reduce < I , F > ( mut iter : I , f : F ) -> Option < I :: Item >
753754where
@@ -777,11 +778,11 @@ mod tests {
777778 assert ! ( base. is_null( ) ) ;
778779
779780 let base_address = base_address ( array) ;
780- assert_eq ! ( base_address, array as * const _ as usize ) ;
781+ assert_eq ! ( base_address, array as * const _ as * mut u8 ) ;
781782
782783 let data_range = data_range ( array) ;
783- assert_eq ! ( data_range. 0 , array. data( ) as usize ) ;
784- assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 6 ) } as usize ) ;
784+ assert_eq ! ( data_range. 0 , array. data( ) as * mut u8 ) ;
785+ assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 6 ) } as * mut u8 ) ;
785786 } ) ;
786787 }
787788
@@ -794,12 +795,12 @@ mod tests {
794795 assert ! ( !base. is_null( ) ) ;
795796
796797 let base_address = base_address ( array) ;
797- assert_ne ! ( base_address, array as * const _ as usize ) ;
798- assert_eq ! ( base_address, base as usize ) ;
798+ assert_ne ! ( base_address, array as * const _ as * mut u8 ) ;
799+ assert_eq ! ( base_address, base as * mut u8 ) ;
799800
800801 let data_range = data_range ( array) ;
801- assert_eq ! ( data_range. 0 , array. data( ) as usize ) ;
802- assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 6 ) } as usize ) ;
802+ assert_eq ! ( data_range. 0 , array. data( ) as * mut u8 ) ;
803+ assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 6 ) } as * mut u8 ) ;
803804 } ) ;
804805 }
805806
@@ -814,18 +815,18 @@ mod tests {
814815 . unwrap ( )
815816 . downcast :: < PyArray2 < f64 > > ( )
816817 . unwrap ( ) ;
817- assert_ne ! ( view as * const _ as usize , array as * const _ as usize ) ;
818+ assert_ne ! ( view as * const _ as * mut u8 , array as * const _ as * mut u8 ) ;
818819
819820 let base = unsafe { ( * view. as_array_ptr ( ) ) . base } ;
820- assert_eq ! ( base as usize , array as * const _ as usize ) ;
821+ assert_eq ! ( base as * mut u8 , array as * const _ as * mut u8 ) ;
821822
822823 let base_address = base_address ( view) ;
823- assert_ne ! ( base_address, view as * const _ as usize ) ;
824- assert_eq ! ( base_address, base as usize ) ;
824+ assert_ne ! ( base_address, view as * const _ as * mut u8 ) ;
825+ assert_eq ! ( base_address, base as * mut u8 ) ;
825826
826827 let data_range = data_range ( view) ;
827- assert_eq ! ( data_range. 0 , array. data( ) as usize ) ;
828- assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 4 ) } as usize ) ;
828+ assert_eq ! ( data_range. 0 , array. data( ) as * mut u8 ) ;
829+ assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 4 ) } as * mut u8 ) ;
829830 } ) ;
830831 }
831832
@@ -840,22 +841,22 @@ mod tests {
840841 . unwrap ( )
841842 . downcast :: < PyArray2 < f64 > > ( )
842843 . unwrap ( ) ;
843- assert_ne ! ( view as * const _ as usize , array as * const _ as usize ) ;
844+ assert_ne ! ( view as * const _ as * mut u8 , array as * const _ as * mut u8 ) ;
844845
845846 let base = unsafe { ( * view. as_array_ptr ( ) ) . base } ;
846- assert_eq ! ( base as usize , array as * const _ as usize ) ;
847+ assert_eq ! ( base as * mut u8 , array as * const _ as * mut u8 ) ;
847848
848849 let base = unsafe { ( * array. as_array_ptr ( ) ) . base } ;
849850 assert ! ( !base. is_null( ) ) ;
850851
851852 let base_address = base_address ( view) ;
852- assert_ne ! ( base_address, view as * const _ as usize ) ;
853- assert_ne ! ( base_address, array as * const _ as usize ) ;
854- assert_eq ! ( base_address, base as usize ) ;
853+ assert_ne ! ( base_address, view as * const _ as * mut u8 ) ;
854+ assert_ne ! ( base_address, array as * const _ as * mut u8 ) ;
855+ assert_eq ! ( base_address, base as * mut u8 ) ;
855856
856857 let data_range = data_range ( view) ;
857- assert_eq ! ( data_range. 0 , array. data( ) as usize ) ;
858- assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 4 ) } as usize ) ;
858+ assert_eq ! ( data_range. 0 , array. data( ) as * mut u8 ) ;
859+ assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 4 ) } as * mut u8 ) ;
859860 } ) ;
860861 }
861862
@@ -870,31 +871,31 @@ mod tests {
870871 . unwrap ( )
871872 . downcast :: < PyArray2 < f64 > > ( )
872873 . unwrap ( ) ;
873- assert_ne ! ( view1 as * const _ as usize , array as * const _ as usize ) ;
874+ assert_ne ! ( view1 as * const _ as * mut u8 , array as * const _ as * mut u8 ) ;
874875
875876 let locals = [ ( "view1" , view1) ] . into_py_dict ( py) ;
876877 let view2 = py
877878 . eval ( "view1[:,0]" , None , Some ( locals) )
878879 . unwrap ( )
879880 . downcast :: < PyArray1 < f64 > > ( )
880881 . unwrap ( ) ;
881- assert_ne ! ( view2 as * const _ as usize , array as * const _ as usize ) ;
882- assert_ne ! ( view2 as * const _ as usize , view1 as * const _ as usize ) ;
882+ assert_ne ! ( view2 as * const _ as * mut u8 , array as * const _ as * mut u8 ) ;
883+ assert_ne ! ( view2 as * const _ as * mut u8 , view1 as * const _ as * mut u8 ) ;
883884
884885 let base = unsafe { ( * view2. as_array_ptr ( ) ) . base } ;
885- assert_eq ! ( base as usize , array as * const _ as usize ) ;
886+ assert_eq ! ( base as * mut u8 , array as * const _ as * mut u8 ) ;
886887
887888 let base = unsafe { ( * view1. as_array_ptr ( ) ) . base } ;
888- assert_eq ! ( base as usize , array as * const _ as usize ) ;
889+ assert_eq ! ( base as * mut u8 , array as * const _ as * mut u8 ) ;
889890
890891 let base_address = base_address ( view2) ;
891- assert_ne ! ( base_address, view2 as * const _ as usize ) ;
892- assert_ne ! ( base_address, view1 as * const _ as usize ) ;
893- assert_eq ! ( base_address, base as usize ) ;
892+ assert_ne ! ( base_address, view2 as * const _ as * mut u8 ) ;
893+ assert_ne ! ( base_address, view1 as * const _ as * mut u8 ) ;
894+ assert_eq ! ( base_address, base as * mut u8 ) ;
894895
895896 let data_range = data_range ( view2) ;
896- assert_eq ! ( data_range. 0 , array. data( ) as usize ) ;
897- assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 1 ) } as usize ) ;
897+ assert_eq ! ( data_range. 0 , array. data( ) as * mut u8 ) ;
898+ assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 1 ) } as * mut u8 ) ;
898899 } ) ;
899900 }
900901
@@ -909,35 +910,35 @@ mod tests {
909910 . unwrap ( )
910911 . downcast :: < PyArray2 < f64 > > ( )
911912 . unwrap ( ) ;
912- assert_ne ! ( view1 as * const _ as usize , array as * const _ as usize ) ;
913+ assert_ne ! ( view1 as * const _ as * mut u8 , array as * const _ as * mut u8 ) ;
913914
914915 let locals = [ ( "view1" , view1) ] . into_py_dict ( py) ;
915916 let view2 = py
916917 . eval ( "view1[:,0]" , None , Some ( locals) )
917918 . unwrap ( )
918919 . downcast :: < PyArray1 < f64 > > ( )
919920 . unwrap ( ) ;
920- assert_ne ! ( view2 as * const _ as usize , array as * const _ as usize ) ;
921- assert_ne ! ( view2 as * const _ as usize , view1 as * const _ as usize ) ;
921+ assert_ne ! ( view2 as * const _ as * mut u8 , array as * const _ as * mut u8 ) ;
922+ assert_ne ! ( view2 as * const _ as * mut u8 , view1 as * const _ as * mut u8 ) ;
922923
923924 let base = unsafe { ( * view2. as_array_ptr ( ) ) . base } ;
924- assert_eq ! ( base as usize , array as * const _ as usize ) ;
925+ assert_eq ! ( base as * mut u8 , array as * const _ as * mut u8 ) ;
925926
926927 let base = unsafe { ( * view1. as_array_ptr ( ) ) . base } ;
927- assert_eq ! ( base as usize , array as * const _ as usize ) ;
928+ assert_eq ! ( base as * mut u8 , array as * const _ as * mut u8 ) ;
928929
929930 let base = unsafe { ( * array. as_array_ptr ( ) ) . base } ;
930931 assert ! ( !base. is_null( ) ) ;
931932
932933 let base_address = base_address ( view2) ;
933- assert_ne ! ( base_address, view2 as * const _ as usize ) ;
934- assert_ne ! ( base_address, view1 as * const _ as usize ) ;
935- assert_ne ! ( base_address, array as * const _ as usize ) ;
936- assert_eq ! ( base_address, base as usize ) ;
934+ assert_ne ! ( base_address, view2 as * const _ as * mut u8 ) ;
935+ assert_ne ! ( base_address, view1 as * const _ as * mut u8 ) ;
936+ assert_ne ! ( base_address, array as * const _ as * mut u8 ) ;
937+ assert_eq ! ( base_address, base as * mut u8 ) ;
937938
938939 let data_range = data_range ( view2) ;
939- assert_eq ! ( data_range. 0 , array. data( ) as usize ) ;
940- assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 1 ) } as usize ) ;
940+ assert_eq ! ( data_range. 0 , array. data( ) as * mut u8 ) ;
941+ assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 1 ) } as * mut u8 ) ;
941942 } ) ;
942943 }
943944
@@ -952,19 +953,19 @@ mod tests {
952953 . unwrap ( )
953954 . downcast :: < PyArray3 < f64 > > ( )
954955 . unwrap ( ) ;
955- assert_ne ! ( view as * const _ as usize , array as * const _ as usize ) ;
956+ assert_ne ! ( view as * const _ as * mut u8 , array as * const _ as * mut u8 ) ;
956957
957958 let base = unsafe { ( * view. as_array_ptr ( ) ) . base } ;
958- assert_eq ! ( base as usize , array as * const _ as usize ) ;
959+ assert_eq ! ( base as * mut u8 , array as * const _ as * mut u8 ) ;
959960
960961 let base_address = base_address ( view) ;
961- assert_ne ! ( base_address, view as * const _ as usize ) ;
962- assert_eq ! ( base_address, base as usize ) ;
962+ assert_ne ! ( base_address, view as * const _ as * mut u8 ) ;
963+ assert_eq ! ( base_address, base as * mut u8 ) ;
963964
964965 let data_range = data_range ( view) ;
965966 assert_eq ! ( view. data( ) , unsafe { array. data( ) . offset( 2 ) } ) ;
966- assert_eq ! ( data_range. 0 , unsafe { view. data( ) . offset( -2 ) } as usize ) ;
967- assert_eq ! ( data_range. 1 , unsafe { view. data( ) . offset( 4 ) } as usize ) ;
967+ assert_eq ! ( data_range. 0 , unsafe { view. data( ) . offset( -2 ) } as * mut u8 ) ;
968+ assert_eq ! ( data_range. 1 , unsafe { view. data( ) . offset( 4 ) } as * mut u8 ) ;
968969 } ) ;
969970 }
970971
@@ -977,11 +978,11 @@ mod tests {
977978 assert ! ( base. is_null( ) ) ;
978979
979980 let base_address = base_address ( array) ;
980- assert_eq ! ( base_address, array as * const _ as usize ) ;
981+ assert_eq ! ( base_address, array as * const _ as * mut u8 ) ;
981982
982983 let data_range = data_range ( array) ;
983- assert_eq ! ( data_range. 0 , array. data( ) as usize ) ;
984- assert_eq ! ( data_range. 1 , array. data( ) as usize ) ;
984+ assert_eq ! ( data_range. 0 , array. data( ) as * mut u8 ) ;
985+ assert_eq ! ( data_range. 1 , array. data( ) as * mut u8 ) ;
985986 } ) ;
986987 }
987988
0 commit comments