1- use numpy:: ndarray:: { ArrayD , ArrayViewD , ArrayViewMutD , Zip } ;
1+ use std:: ops:: Add ;
2+
3+ use numpy:: ndarray:: { Array1 , ArrayD , ArrayView1 , ArrayViewD , ArrayViewMutD , Zip } ;
24use numpy:: {
35 datetime:: { units, Timedelta } ,
46 Complex64 , IntoPyArray , PyArray1 , PyArrayDyn , PyReadonlyArray1 , PyReadonlyArrayDyn ,
@@ -7,7 +9,7 @@ use numpy::{
79use pyo3:: {
810 pymodule,
911 types:: { PyDict , PyModule } ,
10- PyResult , Python ,
12+ FromPyObject , PyAny , PyResult , Python ,
1113} ;
1214
1315#[ pymodule]
@@ -27,6 +29,11 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
2729 x. map ( |c| c. conj ( ) )
2830 }
2931
32+ // example using generics
33+ fn generic_add < T : Copy + Add < Output = T > > ( x : ArrayView1 < T > , y : ArrayView1 < T > ) -> Array1 < T > {
34+ & x + & y
35+ }
36+
3037 // wrapper of `axpy`
3138 #[ pyfn( m) ]
3239 #[ pyo3( name = "axpy" ) ]
@@ -84,5 +91,47 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
8491 . apply ( |x, y| * x = ( i64:: from ( * x) + 60 * i64:: from ( * y) ) . into ( ) ) ;
8592 }
8693
94+ // This crate follows a strongly-typed approach to wrapping NumPy arrays
95+ // while Python API are often expected to work with multiple element types.
96+ //
97+ // That kind of limited polymorphis can be recovered by accepting an enumerated type
98+ // covering the supported element types and dispatching into a generic implementation.
99+ #[ derive( FromPyObject ) ]
100+ enum SupportedArray < ' py > {
101+ F64 ( & ' py PyArray1 < f64 > ) ,
102+ I64 ( & ' py PyArray1 < i64 > ) ,
103+ }
104+
105+ #[ pyfn( m) ]
106+ fn polymorphic_add < ' py > (
107+ x : SupportedArray < ' py > ,
108+ y : SupportedArray < ' py > ,
109+ ) -> PyResult < & ' py PyAny > {
110+ match ( x, y) {
111+ ( SupportedArray :: F64 ( x) , SupportedArray :: F64 ( y) ) => Ok ( generic_add (
112+ x. readonly ( ) . as_array ( ) ,
113+ y. readonly ( ) . as_array ( ) ,
114+ )
115+ . into_pyarray ( x. py ( ) )
116+ . into ( ) ) ,
117+ ( SupportedArray :: I64 ( x) , SupportedArray :: I64 ( y) ) => Ok ( generic_add (
118+ x. readonly ( ) . as_array ( ) ,
119+ y. readonly ( ) . as_array ( ) ,
120+ )
121+ . into_pyarray ( x. py ( ) )
122+ . into ( ) ) ,
123+ ( SupportedArray :: F64 ( x) , SupportedArray :: I64 ( y) )
124+ | ( SupportedArray :: I64 ( y) , SupportedArray :: F64 ( x) ) => {
125+ let y = y. cast :: < f64 > ( false ) ?;
126+
127+ Ok (
128+ generic_add ( x. readonly ( ) . as_array ( ) , y. readonly ( ) . as_array ( ) )
129+ . into_pyarray ( x. py ( ) )
130+ . into ( ) ,
131+ )
132+ }
133+ }
134+ }
135+
87136 Ok ( ( ) )
88137}
0 commit comments