@@ -553,9 +553,114 @@ class usm_ndarray : public py::object
553553
554554namespace utils
555555{
556+ namespace detail
557+ {
558+ struct ManagedMemory
559+ {
560+ // TODO: do we need to check for memory here? Or can we assume only
561+ // dpnp::tensor::usm_ndarray will be passed?
562+ static bool is_usm_managed_by_shared_ptr (const py::object &h)
563+ {
564+
565+ if (py::isinstance<dpctl::memory::usm_memory>(h)) {
566+ const auto &usm_memory_inst =
567+ py::cast<dpctl::memory::usm_memory>(h);
568+ return usm_memory_inst.is_managed_by_smart_ptr ();
569+ }
570+ else if (py::isinstance<dpnp::tensor::usm_ndarray>(h)) {
571+ const auto &usm_array_inst = py::cast<dpnp::tensor::usm_ndarray>(h);
572+ return usm_array_inst.is_managed_by_smart_ptr ();
573+ }
574+
575+ return false ;
576+ }
577+
578+ static const std::shared_ptr<void > &extract_shared_ptr (const py::object &h)
579+ {
580+ if (py::isinstance<dpctl::memory::usm_memory>(h)) {
581+ const auto &usm_memory_inst =
582+ py::cast<dpctl::memory::usm_memory>(h);
583+ return usm_memory_inst.get_smart_ptr_owner ();
584+ }
585+ else if (py::isinstance<dpnp::tensor::usm_ndarray>(h)) {
586+ const auto &usm_array_inst = py::cast<dpnp::tensor::usm_ndarray>(h);
587+ return usm_array_inst.get_smart_ptr_owner ();
588+ }
589+
590+ throw std::runtime_error (
591+ " Attempted extraction of shared_ptr on an unrecognized type" );
592+ }
593+ };
594+ } // end of namespace detail
595+
596+ template <std::size_t num>
597+ sycl::event keep_args_alive (sycl::queue &q,
598+ const py::object (&py_objs)[num],
599+ const std::vector<sycl::event> &depends = {})
600+ {
601+ std::size_t n_objects_held = 0 ;
602+ std::array<std::shared_ptr<py::handle>, num> shp_arr{};
603+
604+ std::size_t n_usm_owners_held = 0 ;
605+ std::array<std::shared_ptr<void >, num> shp_usm{};
606+
607+ for (std::size_t i = 0 ; i < num; ++i) {
608+ const auto &py_obj_i = py_objs[i];
609+ if (detail::ManagedMemory::is_usm_managed_by_shared_ptr (py_obj_i)) {
610+ const auto &shp =
611+ detail::ManagedMemory::extract_shared_ptr (py_obj_i);
612+ shp_usm[n_usm_owners_held] = shp;
613+ ++n_usm_owners_held;
614+ }
615+ else {
616+ shp_arr[n_objects_held] = std::make_shared<py::handle>(py_obj_i);
617+ shp_arr[n_objects_held]->inc_ref ();
618+ ++n_objects_held;
619+ }
620+ }
621+
622+ bool use_depends = true ;
623+ sycl::event host_task_ev;
624+
625+ if (n_usm_owners_held > 0 ) {
626+ host_task_ev = q.submit ([&](sycl::handler &cgh) {
627+ if (use_depends) {
628+ cgh.depends_on (depends);
629+ use_depends = false ;
630+ }
631+ else {
632+ cgh.depends_on (host_task_ev);
633+ }
634+ cgh.host_task ([shp_usm = std::move (shp_usm)]() {
635+ // no body, but shared pointers are captured in
636+ // the lambda, ensuring that USM allocation is
637+ // kept alive
638+ });
639+ });
640+ }
641+
642+ if (n_objects_held > 0 ) {
643+ host_task_ev = q.submit ([&](sycl::handler &cgh) {
644+ if (use_depends) {
645+ cgh.depends_on (depends);
646+ use_depends = false ;
647+ }
648+ else {
649+ cgh.depends_on (host_task_ev);
650+ }
651+ cgh.host_task ([n_objects_held, shp_arr = std::move (shp_arr)]() {
652+ py::gil_scoped_acquire acquire;
653+
654+ for (std::size_t i = 0 ; i < n_objects_held; ++i) {
655+ shp_arr[i]->dec_ref ();
656+ }
657+ });
658+ });
659+ }
660+
661+ return host_task_ev;
662+ }
556663
557- // add these functions to dpnp::utils for convenience
558- using ::dpctl::utils::keep_args_alive;
559664using ::dpctl::utils::queues_are_compatible;
560665
/*! @brief Check if all allocation queues of usm_ndarrays are the same as
0 commit comments