33
44#include < pybind11/pybind11.h>
55#include < pybind11/numpy.h>
6+ #include < pybind11/stl.h>
67
78#include < graphblas.hpp>
89
1415
1516namespace py = pybind11;
1617
18+
1719// Register all pyalp bindings. Module-local registration can be enabled by
1820// instantiating with ModuleLocal = true. When ModuleLocal==true the
1921// py::module_local() policy is applied to class bindings to avoid symbol
@@ -24,7 +26,7 @@ void register_pyalp(py::module_ &m) {
2426 // Common bindings for all backends
2527 m.def (" backend_name" , [](){ return " backend" ; });
2628
27- if constexpr (ModuleLocal) {
29+ if (ModuleLocal) {
2830 py::class_<grb::Matrix< ScalarType >>(m, " Matrix" , py::module_local ())
2931 .def (py::init ([](size_t m_, size_t n_,
3032 py::array data1,
@@ -51,7 +53,20 @@ void register_pyalp(py::module_ &m) {
5153 py::arg (" m" ),
5254 py::arg (" k_array" )
5355 )
54- .def (" to_numpy" , &to_numpy, " Convert to numpy array" );
56+ .def (" to_numpy" , &to_numpy< ScalarType >, " Convert to numpy array" );
57+
58+ py::class_<grb::Vector< StateType >>(m, " State" , py::module_local ())
59+ .def (py::init<size_t >())
60+ .def (py::init ([](size_t m,
61+ py::array_t <StateType> data3) {
62+ grb::Vector< StateType > vec (m); // call the basic constructor
63+ buildVector (vec, data3); // initialize with data
64+ return vec;
65+ }),
66+ py::arg (" m" ),
67+ py::arg (" k_array" )
68+ )
69+ .def (" to_numpy" , &to_numpy< StateType >, " Convert to numpy array" );
5570 } else {
5671 py::class_<grb::Matrix< ScalarType >>(m, " Matrix" )
5772 .def (py::init ([](size_t m_, size_t n_,
@@ -77,10 +92,54 @@ void register_pyalp(py::module_ &m) {
7792 py::arg (" m" ),
7893 py::arg (" k_array" )
7994 )
80- .def (" to_numpy" , &to_numpy, " Convert to numpy array" );
95+ .def (" to_numpy" , &to_numpy< ScalarType >, " Convert to numpy array" );
96+
97+ py::class_<grb::Vector< StateType >>(m, " State" )
98+ .def (py::init<size_t >())
99+ .def (py::init ([](size_t m,
100+ py::array_t <StateType> data3) {
101+ grb::Vector< StateType > vec (m); // call the basic constructor
102+ buildVector (vec, data3); // initialize with data
103+ return vec;
104+ }),
105+ py::arg (" m" ),
106+ py::arg (" k_array" )
107+ )
108+ .def (" to_numpy" , &to_numpy< StateType >, " Convert to numpy array" );
109+
81110 }
111+
112+ py::class_< std::vector<grb::Vector<StateType>> >(m, " stdVectorStates" )
113+ .def (py::init<size_t >())
114+ .def (py::init ([](
115+ size_t m,
116+ py::array_t < StateType > arr ) {
117+ (void ) m;
82118
83- m.def (" buildVector" , &buildVector, " Fill Vector from 1 NumPy array" );
119+ const py::buffer_info buf = arr.request ();
120+
121+ if (buf.ndim != 2 ) {
122+ throw std::runtime_error (" Input array must be 2-dimensional" );
123+ }
124+ const size_t sz = buf.shape [0 ];
125+ std::vector< grb::Vector<StateType> > vec (sz);
126+ assert ( static_cast <size_t >( buf.shape [1 ] ) <= m );
127+ auto ptr = static_cast <StateType*>(buf.ptr );
128+
129+ grb::RC io_rc = grb::SUCCESS;
130+ for (size_t i = 0 ; i < sz; i++) {
131+ io_rc = io_rc ? io_rc :
132+ grb::buildVector ( vec[i], ptr, ptr + buf.shape [1 ], grb::SEQUENTIAL );
133+ }
134+ return vec;
135+ }),
136+ py::arg (" m" ),
137+ py::arg (" k_array" )
138+ );
139+ // .def("get_vec", & );
140+
141+ m.def (" buildVector" , &buildVector<ScalarType>, " Fill Vector from 1 NumPy array" );
142+ m.def (" buildVectorInt8" , &buildVector<StateType>, " Fill Vector from 1 NumPy array" );
84143 m.def (" print_my_numpy_array" , &print_my_numpy_array, " Print a numpy array as a flattened std::vector" );
85144 m.def (" conjugate_gradient" , &conjugate_gradient, " Pass alp data to alp CG solver" ,
86145 py::arg (" L" ),
@@ -114,4 +173,5 @@ void register_pyalp(py::module_ &m) {
114173 py::arg (" seed" ) = 0 ,
115174 py::arg (" verbose" ) = 0
116175 );
176+
117177}
0 commit comments