Skip to content

Commit 4a8ef52

Browse files
committed
add back ManagedMemory and keep_args_alive
The test slowdown was likely caused by their removal.
1 parent ffa9268 commit 4a8ef52

1 file changed

Lines changed: 107 additions & 2 deletions

File tree

dpnp/backend/include/dpnp4pybind11.hpp

Lines changed: 107 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -553,9 +553,114 @@ class usm_ndarray : public py::object
553553

554554
namespace utils
555555
{
556+
namespace detail
557+
{
558+
struct ManagedMemory
559+
{
560+
// TODO: do we need to check for memory here? Or can we assume only
561+
// dpnp::tensor::usm_ndarray will be passed?
562+
static bool is_usm_managed_by_shared_ptr(const py::object &h)
563+
{
564+
565+
if (py::isinstance<dpctl::memory::usm_memory>(h)) {
566+
const auto &usm_memory_inst =
567+
py::cast<dpctl::memory::usm_memory>(h);
568+
return usm_memory_inst.is_managed_by_smart_ptr();
569+
}
570+
else if (py::isinstance<dpnp::tensor::usm_ndarray>(h)) {
571+
const auto &usm_array_inst = py::cast<dpnp::tensor::usm_ndarray>(h);
572+
return usm_array_inst.is_managed_by_smart_ptr();
573+
}
574+
575+
return false;
576+
}
577+
578+
static const std::shared_ptr<void> &extract_shared_ptr(const py::object &h)
579+
{
580+
if (py::isinstance<dpctl::memory::usm_memory>(h)) {
581+
const auto &usm_memory_inst =
582+
py::cast<dpctl::memory::usm_memory>(h);
583+
return usm_memory_inst.get_smart_ptr_owner();
584+
}
585+
else if (py::isinstance<dpnp::tensor::usm_ndarray>(h)) {
586+
const auto &usm_array_inst = py::cast<dpnp::tensor::usm_ndarray>(h);
587+
return usm_array_inst.get_smart_ptr_owner();
588+
}
589+
590+
throw std::runtime_error(
591+
"Attempted extraction of shared_ptr on an unrecognized type");
592+
}
593+
};
594+
} // end of namespace detail
595+
596+
template <std::size_t num>
597+
sycl::event keep_args_alive(sycl::queue &q,
598+
const py::object (&py_objs)[num],
599+
const std::vector<sycl::event> &depends = {})
600+
{
601+
std::size_t n_objects_held = 0;
602+
std::array<std::shared_ptr<py::handle>, num> shp_arr{};
603+
604+
std::size_t n_usm_owners_held = 0;
605+
std::array<std::shared_ptr<void>, num> shp_usm{};
606+
607+
for (std::size_t i = 0; i < num; ++i) {
608+
const auto &py_obj_i = py_objs[i];
609+
if (detail::ManagedMemory::is_usm_managed_by_shared_ptr(py_obj_i)) {
610+
const auto &shp =
611+
detail::ManagedMemory::extract_shared_ptr(py_obj_i);
612+
shp_usm[n_usm_owners_held] = shp;
613+
++n_usm_owners_held;
614+
}
615+
else {
616+
shp_arr[n_objects_held] = std::make_shared<py::handle>(py_obj_i);
617+
shp_arr[n_objects_held]->inc_ref();
618+
++n_objects_held;
619+
}
620+
}
621+
622+
bool use_depends = true;
623+
sycl::event host_task_ev;
624+
625+
if (n_usm_owners_held > 0) {
626+
host_task_ev = q.submit([&](sycl::handler &cgh) {
627+
if (use_depends) {
628+
cgh.depends_on(depends);
629+
use_depends = false;
630+
}
631+
else {
632+
cgh.depends_on(host_task_ev);
633+
}
634+
cgh.host_task([shp_usm = std::move(shp_usm)]() {
635+
// no body, but shared pointers are captured in
636+
// the lambda, ensuring that USM allocation is
637+
// kept alive
638+
});
639+
});
640+
}
641+
642+
if (n_objects_held > 0) {
643+
host_task_ev = q.submit([&](sycl::handler &cgh) {
644+
if (use_depends) {
645+
cgh.depends_on(depends);
646+
use_depends = false;
647+
}
648+
else {
649+
cgh.depends_on(host_task_ev);
650+
}
651+
cgh.host_task([n_objects_held, shp_arr = std::move(shp_arr)]() {
652+
py::gil_scoped_acquire acquire;
653+
654+
for (std::size_t i = 0; i < n_objects_held; ++i) {
655+
shp_arr[i]->dec_ref();
656+
}
657+
});
658+
});
659+
}
660+
661+
return host_task_ev;
662+
}
556663

557-
// add these functions to dpnp::utils for convenience
558-
using ::dpctl::utils::keep_args_alive;
559664
using ::dpctl::utils::queues_are_compatible;
560665

561666
/*! @brief Check if all allocation queues of usm_ndarrays are the same as

0 commit comments

Comments
 (0)