11#include " IALSLearningConfig.hpp"
22#include " IALSTrainer.hpp"
3- #include " pybind11/cast.h"
43#include < Eigen/Sparse>
54#include < cstddef>
6- #include < pybind11/eigen .h>
7- #include < pybind11/pybind11 .h>
8- #include < pybind11/stl .h>
9- #include < pybind11/stl_bind .h>
10- #include < sstream >
11- #include < stdexcept >
12- #include < vector >
5+ #include < nanobind/nanobind .h>
6+ #include < nanobind/nb_defs .h>
7+ #include < nanobind/eigen/dense .h>
8+ #include < nanobind/eigen/sparse .h>
9+ #include < nanobind/stl/tuple.h >
10+ #include < nanobind/stl/string.h >
11+ #include < tuple >
1312
14- namespace py = pybind11;
1513using namespace irspack ::ials;
16- using std::vector ;
14+ using namespace nanobind ;
1715
18- PYBIND11_MODULE (_ials , m) {
16+ NB_MODULE (_ials_core , m) {
1917 std::stringstream doc_stream;
2018 doc_stream << " irspack's core module for \" IALSRecommender\" ." << std::endl
2119 << " Built to use" << std::endl
2220 << " \t " << Eigen::SimdInstructionSetsInUse ();
2321
2422 m.doc () = doc_stream.str ();
2523
26- py ::enum_<LossType>(m, " LossType" )
24+ nanobind ::enum_<LossType>(m, " LossType" )
2725 .value (" ORIGINAL" , LossType::ORIGINAL)
2826 .value (" IALSPP" , LossType::IALSPP)
2927 .export_values ();
3028
31- py ::enum_<SolverType>(m, " SolverType" )
29+ nanobind ::enum_<SolverType>(m, " SolverType" )
3230 .value (" CHOLESKY" , SolverType::Cholesky)
3331 .value (" CG" , SolverType::CG)
3432 .value (" IALSPP" , SolverType::IALSPP)
3533 .export_values ();
3634
3735 auto model_config =
38- py::class_<IALSModelConfig>(m, " IALSModelConfig" )
39- .def (py::init<size_t , Real, Real, Real, Real, int , LossType>())
40- .def (py::pickle (
41- [](const IALSModelConfig &config) {
42- return py::make_tuple (config.K , config.alpha0 , config.reg ,
43- config.nu , config.init_stdev ,
44- config.random_seed , config.loss_type );
45- },
46- [](py::tuple t) {
47- if (t.size () != 7 )
48- throw std::runtime_error (" invalid state" );
49-
50- size_t K = t[0 ].cast <size_t >();
51- Real alpha0 = t[1 ].cast <Real>();
52- Real reg = t[2 ].cast <Real>();
53- Real nu = t[3 ].cast <Real>();
54- Real init_stdev = t[4 ].cast <Real>();
55- int random_seed = t[5 ].cast <int >();
56- LossType loss_type = t[6 ].cast <LossType>();
57- return IALSModelConfig (K, alpha0, reg, nu, init_stdev,
58- random_seed, loss_type);
59- }));
60- py::class_<IALSModelConfig::Builder>(m, " IALSModelConfigBuilder" )
61- .def (py::init<>())
36+ nanobind::class_<IALSModelConfig>(m, " IALSModelConfig" )
37+ .def (nanobind::init<size_t , Real, Real, Real, Real, int , LossType>())
38+ .def (" __getstate__" ,
39+ [](const IALSModelConfig &config) {
40+ return nanobind::make_tuple (
41+ config.K , config.alpha0 , config.reg , config.nu ,
42+ config.init_stdev , config.random_seed , config.loss_type );
43+ })
44+ .def (" __setstate__" ,
45+ [](IALSModelConfig &ials_model_config,
46+ const std::tuple<size_t , Real, Real, Real, Real, int ,
47+ LossType> &state) {
48+ new (&ials_model_config) IALSModelConfig (
49+ std::get<0 >(state), std::get<1 >(state), std::get<2 >(state),
50+ std::get<3 >(state), std::get<4 >(state), std::get<5 >(state),
51+ std::get<6 >(state));
52+ });
53+ nanobind::class_<IALSModelConfig::Builder>(m, " IALSModelConfigBuilder" )
54+ .def (nanobind::init<>())
6255 .def (" build" , &IALSModelConfig::Builder::build)
6356 .def (" set_K" , &IALSModelConfig::Builder::set_K)
6457 .def (" set_alpha0" , &IALSModelConfig::Builder::set_alpha0)
@@ -69,30 +62,28 @@ PYBIND11_MODULE(_ials, m) {
6962 .def (" set_loss_type" , &IALSModelConfig::Builder::set_loss_type);
7063
7164 auto solver_config =
72- py ::class_<SolverConfig>(m, " IALSSolverConfig" )
73- .def (py ::init<size_t , SolverType, size_t , size_t , size_t >())
74- .def (py::pickle (
75- [](const SolverConfig &config) {
76- return py ::make_tuple (
65+ nanobind ::class_<SolverConfig>(m, " IALSSolverConfig" )
66+ .def (nanobind ::init<size_t , SolverType, size_t , size_t , size_t >())
67+ .def (
68+ " __getstate__ " , [](const SolverConfig &config) {
69+ return std ::make_tuple (
7770 config.n_threads , config.solver_type , config.max_cg_steps ,
7871 config.ialspp_subspace_dimension , config.ialspp_iteration );
79- },
80- [](py::tuple t) {
81- if (t.size () != 5 )
82- throw std::runtime_error (" invalid state" );
83-
84- size_t n_threads = t[0 ].cast <size_t >();
85- SolverType solver_type = t[1 ].cast <SolverType>();
86- size_t max_cg_steps = t[2 ].cast <size_t >();
87- size_t ialspp_subspace_dimension = t[3 ].cast <size_t >();
88- size_t ialspp_iteration = t[4 ].cast <size_t >();
89- return SolverConfig (n_threads, solver_type, max_cg_steps,
90- ialspp_subspace_dimension,
91- ialspp_iteration);
92- }));
72+ }
73+ )
74+ .def (
75+ " __setstate__" , [](SolverConfig &config, const std::tuple<size_t , SolverType, size_t , size_t , size_t > &state) {
76+ new (&config) SolverConfig (
77+ std::get<0 >(state),
78+ std::get<1 >(state),
79+ std::get<2 >(state),
80+ std::get<3 >(state),
81+ std::get<4 >(state));
82+ }
83+ );
9384
94- py ::class_<SolverConfig::Builder>(m, " IALSSolverConfigBuilder" )
95- .def (py ::init<>())
85+ nanobind ::class_<SolverConfig::Builder>(m, " IALSSolverConfigBuilder" )
86+ .def (nanobind ::init<>())
9687 .def (" build" , &SolverConfig::Builder::build)
9788 .def (" set_n_threads" , &SolverConfig::Builder::set_n_threads)
9889 .def (" set_solver_type" , &SolverConfig::Builder::set_solver_type)
@@ -102,25 +93,23 @@ PYBIND11_MODULE(_ials, m) {
10293 .def (" set_ialspp_iteration" ,
10394 &SolverConfig::Builder::set_ialspp_iteration);
10495
105- py ::class_<IALSTrainer>(m, " IALSTrainer" )
106- .def (py ::init<IALSModelConfig, const SparseMatrix &>())
96+ nanobind ::class_<IALSTrainer>(m, " IALSTrainer" )
97+ .def (nanobind ::init<IALSModelConfig, const SparseMatrix &>())
10798 .def (" step" , &IALSTrainer::step)
10899 .def (" user_scores" , &IALSTrainer::user_scores)
109100 .def (" transform_user" , &IALSTrainer::transform_user)
110101 .def (" transform_item" , &IALSTrainer::transform_item)
111102 .def (" compute_loss" , &IALSTrainer::compute_loss)
112- .def_readwrite (" user" , &IALSTrainer::user)
113- .def_readwrite (" item" , &IALSTrainer::item)
114- .def (py::pickle (
115- [](const IALSTrainer &trainer) {
116- return py::make_tuple (trainer.config_ , trainer.user , trainer.item );
117- },
118- [](py::tuple t) {
119- if (t.size () != 3 )
120- throw std::runtime_error (" Invalid state!" );
121- IALSTrainer trainer (t[0 ].cast <IALSModelConfig>(),
122- t[1 ].cast <DenseMatrix>(),
123- t[2 ].cast <DenseMatrix>());
124- return trainer;
125- }));
103+ .def_rw (" user" , &IALSTrainer::user)
104+ .def_rw (" item" , &IALSTrainer::item)
105+ .def (" __getstate__" , [](const IALSTrainer & trainer) {
106+ return std::make_tuple (trainer.config_ , trainer.user , trainer.item );
107+ })
108+ .def (" __setstate__" , [](IALSTrainer & trainer, const std::tuple<IALSModelConfig, DenseMatrix, DenseMatrix> & state) {
109+ new (&trainer) IALSTrainer (
110+ std::get<0 >(state), std::get<1 >(state),
111+ std::get<2 >(state)
112+
113+ );
114+ });
126115}
0 commit comments