Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions sycl/source/detail/graph/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -545,15 +545,14 @@ graph_impl::add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
return NodeImpl;
}

std::shared_ptr<sycl::detail::queue_impl> graph_impl::getQueue() const {
std::shared_ptr<sycl::detail::queue_impl> Return{};
if (!MRecordingQueues.empty())
Return = MRecordingQueues.begin()->lock();
return Return;
/// Returns the queue currently referenced by MLastRecordedQueue.
/// @return Shared pointer to that queue, or an empty pointer when the weak
/// reference cannot be locked.
std::shared_ptr<sycl::detail::queue_impl>
graph_impl::getLastRecordedQueue() const {
  // MLastRecordedQueue is a weak reference; lock() yields an empty shared_ptr
  // if it was never assigned or the queue it referred to no longer exists.
  std::shared_ptr<sycl::detail::queue_impl> LastQueue =
      MLastRecordedQueue.lock();
  return LastQueue;
}

/// Registers RecordingQueue in the set of queues recording to this graph and
/// remembers it as the most recently added queue.
/// @param RecordingQueue Queue to register.
void graph_impl::addQueue(sycl::detail::queue_impl &RecordingQueue) {
  // Take the weak reference once and reuse it for both the last-recorded slot
  // and the set of recording queues; inserting the same queue a second time
  // into the unique set would be a no-op.
  MLastRecordedQueue = RecordingQueue.weak_from_this();
  MRecordingQueues.insert(MLastRecordedQueue);
}

void graph_impl::removeQueue(sycl::detail::queue_impl &RecordingQueue) {
Expand Down Expand Up @@ -932,7 +931,7 @@ exec_graph_impl::exec_graph_impl(sycl::context Context,
// Copy nodes from GraphImpl and merge any subgraph nodes into this graph.
duplicateNodes();

if (auto PlaceholderQueuePtr = GraphImpl->getQueue()) {
if (auto PlaceholderQueuePtr = GraphImpl->getLastRecordedQueue()) {
MQueueImpl = std::move(PlaceholderQueuePtr);
} else {
MQueueImpl = sycl::detail::queue_impl::create(
Expand Down
6 changes: 5 additions & 1 deletion sycl/source/detail/graph/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
node_impl &add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
nodes_range Deps);

std::shared_ptr<sycl::detail::queue_impl> getQueue() const;
/// Get queue that was last recorded from.
/// @return Queue that most recently started recording into this graph, or an
/// empty pointer if no such queue is available.
std::shared_ptr<sycl::detail::queue_impl> getLastRecordedQueue() const;

/// Add a queue to the set of queues which are currently recording to this
/// graph.
Expand Down Expand Up @@ -558,6 +560,8 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>;
/// Unique set of queues which are currently recording to this graph.
RecQueuesStorage MRecordingQueues;
/// Queue that was most recently added as a recording queue for this graph.
std::weak_ptr<sycl::detail::queue_impl> MLastRecordedQueue;
/// Map of events to their associated recorded nodes.
std::unordered_map<std::shared_ptr<sycl::detail::event_impl>, node_impl *>
MEventsMap;
Expand Down
4 changes: 4 additions & 0 deletions sycl/unittests/Extensions/CommandGraph/Common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ class GraphImplTest {
/// Number of entries in the executable graph's MSyncPoints container.
static int NumSyncPoints(const exec_graph_impl &Impl) {
  const auto Count = Impl.MSyncPoints.size();
  return static_cast<int>(Count);
}
/// Exposes the MQueueImpl member of an executable graph to tests.
static std::shared_ptr<sycl::detail::queue_impl>
GetQueueImpl(const exec_graph_impl &Impl) {
  std::shared_ptr<sycl::detail::queue_impl> QueueImpl = Impl.MQueueImpl;
  return QueueImpl;
}
};

// Common Test fixture
Expand Down
36 changes: 36 additions & 0 deletions sycl/unittests/Extensions/CommandGraph/Regressions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,39 @@ TEST_F(CommandGraphTest, QueueRecordBarrierMultipleGraph) {
Queue.ext_oneapi_submit_barrier();
GraphC.end_recording(Queue);
}

// Regression test: the queue a graph was recorded from must still be
// retrievable once recording has ended.
// The earlier accessor, getQueue(), read from the MRecordingQueues set, so it
// returned nullptr once end_recording() had cleared that set.
// getLastRecordedQueue() reads MLastRecordedQueue instead, which lets the
// executable graph pick up the recording queue when it is finalized.
TEST_F(CommandGraphTest, LastRecordedQueueAfterCleanup) {
  // Record a single empty kernel through Queue, then stop recording.
  Graph.begin_recording(Queue);
  Queue.submit(
      [&](sycl::handler &CGH) { CGH.single_task<TestKernel>([]() {}); });
  Graph.end_recording(Queue);

  // Implementation object of the queue every check below compares against.
  const auto ExpectedQueueImpl = getSyclObjImpl(Queue);

  // Inspect the graph's internal state: the last recorded queue must be
  // non-null and identical to the queue used above, even though recording
  // has already ended.
  auto RecordedGraphImpl = getSyclObjImpl(Graph);
  auto RecordedQueueImpl = RecordedGraphImpl->getLastRecordedQueue();
  EXPECT_NE(RecordedQueueImpl, nullptr);
  EXPECT_EQ(RecordedQueueImpl, ExpectedQueueImpl);

  // Finalizing consults getLastRecordedQueue() to choose the executable
  // graph's queue.
  auto ExecGraph = Graph.finalize();
  experimental::detail::exec_graph_impl &ExecImpl =
      *getSyclObjImpl(ExecGraph);

  // The executable graph must have adopted the recording queue.
  auto ExecQueue = GraphImplTest::GetQueueImpl(ExecImpl);
  EXPECT_NE(ExecQueue, nullptr);
  EXPECT_EQ(ExecQueue, ExpectedQueueImpl);
}
Loading