diff --git a/sycl/source/detail/graph/graph_impl.cpp b/sycl/source/detail/graph/graph_impl.cpp index 68de839376647..213ddc14b63a9 100644 --- a/sycl/source/detail/graph/graph_impl.cpp +++ b/sycl/source/detail/graph/graph_impl.cpp @@ -545,15 +545,14 @@ graph_impl::add(std::shared_ptr &DynCGImpl, return NodeImpl; } -std::shared_ptr graph_impl::getQueue() const { - std::shared_ptr Return{}; - if (!MRecordingQueues.empty()) - Return = MRecordingQueues.begin()->lock(); - return Return; +std::shared_ptr +graph_impl::getLastRecordedQueue() const { + return MLastRecordedQueue.lock(); } void graph_impl::addQueue(sycl::detail::queue_impl &RecordingQueue) { - MRecordingQueues.insert(RecordingQueue.weak_from_this()); + MLastRecordedQueue = RecordingQueue.weak_from_this(); + MRecordingQueues.insert(MLastRecordedQueue); } void graph_impl::removeQueue(sycl::detail::queue_impl &RecordingQueue) { @@ -932,7 +931,7 @@ exec_graph_impl::exec_graph_impl(sycl::context Context, // Copy nodes from GraphImpl and merge any subgraph nodes into this graph. duplicateNodes(); - if (auto PlaceholderQueuePtr = GraphImpl->getQueue()) { + if (auto PlaceholderQueuePtr = GraphImpl->getLastRecordedQueue()) { MQueueImpl = std::move(PlaceholderQueuePtr); } else { MQueueImpl = sycl::detail::queue_impl::create( diff --git a/sycl/source/detail/graph/graph_impl.hpp b/sycl/source/detail/graph/graph_impl.hpp index e6637032676cb..7f9e052fea20e 100644 --- a/sycl/source/detail/graph/graph_impl.hpp +++ b/sycl/source/detail/graph/graph_impl.hpp @@ -176,7 +176,9 @@ class graph_impl : public std::enable_shared_from_this { node_impl &add(std::shared_ptr &DynCGImpl, nodes_range Deps); - std::shared_ptr getQueue() const; + /// Get the queue most recently registered through addQueue(). + /// @return That queue, or nullptr if none was added or it has expired. + std::shared_ptr getLastRecordedQueue() const; /// Add a queue to the set of queues which are currently recording to this /// graph. 
@@ -558,6 +560,8 @@ class graph_impl : public std::enable_shared_from_this { std::owner_less>>; /// Unique set of queues which are currently recording to this graph. RecQueuesStorage MRecordingQueues; + /// Most recent queue passed to addQueue(); held weakly (non-owning). + std::weak_ptr MLastRecordedQueue; /// Map of events to their associated recorded nodes. std::unordered_map, node_impl *> MEventsMap; diff --git a/sycl/unittests/Extensions/CommandGraph/Common.hpp b/sycl/unittests/Extensions/CommandGraph/Common.hpp index aa7c59472105e..55dadfbcadbdf 100644 --- a/sycl/unittests/Extensions/CommandGraph/Common.hpp +++ b/sycl/unittests/Extensions/CommandGraph/Common.hpp @@ -46,6 +46,10 @@ class GraphImplTest { static int NumSyncPoints(const exec_graph_impl &Impl) { return Impl.MSyncPoints.size(); } + static std::shared_ptr + GetQueueImpl(const exec_graph_impl &Impl) { + return Impl.MQueueImpl; + } }; // Common Test fixture diff --git a/sycl/unittests/Extensions/CommandGraph/Regressions.cpp b/sycl/unittests/Extensions/CommandGraph/Regressions.cpp index 0833357f14357..26f1ef70e60e8 100644 --- a/sycl/unittests/Extensions/CommandGraph/Regressions.cpp +++ b/sycl/unittests/Extensions/CommandGraph/Regressions.cpp @@ -94,3 +94,39 @@ TEST_F(CommandGraphTest, QueueRecordBarrierMultipleGraph) { Queue.ext_oneapi_submit_barrier(); GraphC.end_recording(Queue); } + +// Test that the last recorded queue is preserved after cleanup. +// This is a regression test for a bug where the graph's queue lookup would +// return nullptr after the recording queues were cleaned up, because the +// previous implementation (getQueue()) looked in the MRecordingQueues set +// which gets cleared on end_recording(). The fix introduces MLastRecordedQueue +// which persists even after cleanup, allowing the executable graph to retrieve +// the queue that was used for recording. 
+TEST_F(CommandGraphTest, LastRecordedQueueAfterCleanup) { + // Record some work to the graph + Graph.begin_recording(Queue); + Queue.submit( + [&](sycl::handler &cgh) { cgh.single_task([]() {}); }); + Graph.end_recording(Queue); + + // Get the graph implementation to check internal state + auto GraphImpl = getSyclObjImpl(Graph); + + // getLastRecordedQueue() should return the queue that was used for recording + // even after end_recording() has cleared the recording queues + auto LastQueue = GraphImpl->getLastRecordedQueue(); + EXPECT_NE(LastQueue, nullptr); + EXPECT_EQ(LastQueue, getSyclObjImpl(Queue)); + + // Finalize the graph - this uses getLastRecordedQueue() internally + // to set up the executable graph's queue. When that lookup is empty (as it + // was before the fix), a new internal queue is created instead. + auto GraphExec = Graph.finalize(); + experimental::detail::exec_graph_impl &ExecGraphImpl = + *getSyclObjImpl(GraphExec); + + // The executable graph should have the queue from recording + auto ExecQueueImpl = GraphImplTest::GetQueueImpl(ExecGraphImpl); + EXPECT_NE(ExecQueueImpl, nullptr); + EXPECT_EQ(ExecQueueImpl, getSyclObjImpl(Queue)); +}