Skip to content

Commit be45d64

Browse files
authored
Merge branch 'main' into stabilize-queue-unit-tests
2 parents 144eb00 + b1246b8 commit be45d64

File tree

5 files changed

+502
-36
lines changed

5 files changed

+502
-36
lines changed

Include/XTaskQueue.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,27 @@ STDAPI_(void) XTaskQueueCloseHandle(
185185
/// preventing new items from being queued. Once a queue is terminated
186186
/// its handle can be closed. New items cannot be enqueued to a task
187187
/// queue that has been terminated.
188+
///
189+
/// Each task queue terminates independently. For composite queues created
190+
/// with XTaskQueueCreateComposite, terminating a composite queue does NOT
191+
/// automatically terminate other queues sharing the same ports.
192+
///
193+
/// When wait=true:
194+
/// - Blocks until this queue's termination callback completes
195+
/// - Does NOT wait for other independent queues (including composite delegates)
196+
/// - Ensures this queue's termination callback has finished executing
197+
/// - Safe to unload code/modules after this returns
198+
///
199+
/// When wait=false:
200+
/// - Returns immediately after initiating termination
201+
/// - The termination callback will be invoked asynchronously when termination completes
202+
///
203+
/// The termination callback is invoked after all work and completion callbacks
204+
/// have been canceled or completed. After the termination callback returns, the
205+
/// implementation will no longer access the queue's internal state.
206+
///
207+
/// For coordinated shutdown of multiple queues sharing ports, use termination
208+
/// callbacks to track completion of each queue before performing final cleanup.
188209
/// </summary>
189210
/// <param name='queue'>The queue to terminate.</param>
190211
/// <param name='wait'>True to wait for the termination to complete.</param>

Source/Task/TaskQueue.cpp

Lines changed: 90 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,10 @@ HRESULT __stdcall TaskQueuePortImpl::PrepareTerminate(
518518
std::unique_ptr<TerminationEntry> term(new (std::nothrow) TerminationEntry);
519519
RETURN_IF_NULL_ALLOC(term);
520520

521-
RETURN_HR_IF(E_OUTOFMEMORY, !m_terminationList->reserve_node(term->node));
521+
{
522+
std::lock_guard<std::mutex> lock(m_terminationLock);
523+
RETURN_HR_IF(E_OUTOFMEMORY, !m_terminationList->reserve_node(term->node));
524+
}
522525

523526
term->callbackContext = callbackContext;
524527
term->callback = callback;
@@ -540,6 +543,7 @@ void __stdcall TaskQueuePortImpl::CancelTermination(
540543

541544
if (term->node != 0)
542545
{
546+
std::lock_guard<std::mutex> lock(m_terminationLock);
543547
m_terminationList->free_node(term->node);
544548
}
545549

@@ -565,6 +569,7 @@ void __stdcall TaskQueuePortImpl::Terminate(
565569
}
566570
else
567571
{
572+
std::lock_guard<std::mutex> lock(m_terminationLock);
568573
m_pendingTerminationList->push_back(term, term->node);
569574
term->node = 0;
570575
}
@@ -687,7 +692,7 @@ bool TaskQueuePortImpl::Wait(
687692
_In_ uint32_t timeout)
688693
{
689694
#ifdef _WIN32
690-
while (m_suspended || (m_queueList->empty() && m_terminationList->empty()))
695+
while (m_suspended || (m_queueList->empty() && TerminationListEmpty()))
691696
{
692697
if (portContext->GetStatus() == TaskQueuePortStatus::Terminated)
693698
{
@@ -749,7 +754,7 @@ bool TaskQueuePortImpl::Wait(
749754
}
750755

751756
#else
752-
while (m_suspended || (m_queueList->empty() && m_terminationList->empty()))
757+
while (m_suspended || (m_queueList->empty() && TerminationListEmpty()))
753758
{
754759
if (portContext->GetStatus() == TaskQueuePortStatus::Terminated)
755760
{
@@ -825,20 +830,31 @@ void __stdcall TaskQueuePortImpl::ResumeTermination(
825830
{
826831
// Removed the last external callback. Look for
827832
// parked terminations and reschedule them.
828-
829-
m_pendingTerminationList->remove_if([&](auto& entry, auto address)
833+
// Use a temporary list sharing the same heap to avoid allocation
834+
LocklessQueue<TerminationEntry*> entries_to_schedule(*m_pendingTerminationList.get());
835+
830836
{
831-
if (entry->portContext == portContext)
837+
std::lock_guard<std::mutex> lock(m_terminationLock);
838+
m_pendingTerminationList->remove_if([&](auto& entry, auto address)
832839
{
833-
// This entry is for the port that's resuming,
834-
// we can schedule it.
835-
entry->node = address;
836-
ScheduleTermination(entry);
837-
return true;
838-
}
840+
if (entry->portContext == portContext)
841+
{
842+
entries_to_schedule.push_back(entry, address);
843+
return true;
844+
}
839845

840-
return false;
841-
});
846+
return false;
847+
});
848+
}
849+
850+
// Schedule entries outside the lock
851+
TerminationEntry* entry;
852+
uint64_t address;
853+
while (entries_to_schedule.pop_front(entry, address))
854+
{
855+
entry->node = address;
856+
ScheduleTermination(entry);
857+
}
842858
}
843859
}
844860

@@ -872,18 +888,21 @@ void __stdcall TaskQueuePortImpl::ResumePort()
872888
m_queueList->push_back(std::move(queueEntry), address);
873889
}
874890

875-
TerminationEntry* terminationEntry;
876-
LocklessQueue<TerminationEntry*> retainTerminations(*(m_terminationList.get()));
877-
878-
while (m_terminationList->pop_front(terminationEntry, address))
879891
{
880-
notifyCount++;
881-
retainTerminations.push_back(std::move(terminationEntry), address);
882-
}
892+
std::lock_guard<std::mutex> lock(m_terminationLock);
893+
TerminationEntry* terminationEntry;
894+
LocklessQueue<TerminationEntry*> retainTerminations(*(m_terminationList.get()));
883895

884-
while (retainTerminations.pop_front(terminationEntry, address))
885-
{
886-
m_terminationList->push_back(std::move(terminationEntry), address);
896+
while (m_terminationList->pop_front(terminationEntry, address))
897+
{
898+
notifyCount++;
899+
retainTerminations.push_back(std::move(terminationEntry), address);
900+
}
901+
902+
while (retainTerminations.pop_front(terminationEntry, address))
903+
{
904+
m_terminationList->push_back(std::move(terminationEntry), address);
905+
}
887906
}
888907

889908
m_suspended = false;
@@ -1199,19 +1218,46 @@ void TaskQueuePortImpl::NotifyItemQueued()
11991218

12001219
void TaskQueuePortImpl::SignalTerminations()
12011220
{
1202-
m_terminationList->remove_if([this](auto& entry, auto address)
1221+
// Collect entries to process outside the iteration to avoid concurrent modification races
1222+
// when callbacks invoke nested Terminate() calls.
1223+
// Use a temporary list sharing the same heap to avoid allocation.
1224+
LocklessQueue<TerminationEntry*> entries_to_process(*m_terminationList.get());
1225+
12031226
{
1204-
if (entry->portContext->GetStatus() >= TaskQueuePortStatus::Terminating)
1227+
std::lock_guard<std::mutex> lock(m_terminationLock);
1228+
m_terminationList->remove_if([this, &entries_to_process](auto& entry, auto address)
12051229
{
1206-
entry->portContext->SetStatus(TaskQueuePortStatus::Terminated);
1207-
entry->callback(entry->callbackContext);
1230+
if (entry->portContext->GetStatus() >= TaskQueuePortStatus::Terminating)
1231+
{
1232+
entry->portContext->SetStatus(TaskQueuePortStatus::Terminated);
1233+
entries_to_process.push_back(entry, address);
1234+
return true;
1235+
}
1236+
1237+
return false;
1238+
});
1239+
}
1240+
1241+
// Now process callbacks outside the remove_if iteration
1242+
// This prevents races when callbacks invoke nested operations like Terminate()
1243+
TerminationEntry* entry;
1244+
uint64_t address;
1245+
while (entries_to_process.pop_front(entry, address))
1246+
{
1247+
// AddRef portContext to prevent UAF if callback releases the queue
1248+
entry->portContext->AddRef();
1249+
1250+
entry->callback(entry->callbackContext);
1251+
1252+
// Release portContext after callback completes
1253+
entry->portContext->Release();
1254+
1255+
{
1256+
std::lock_guard<std::mutex> lock(m_terminationLock);
12081257
m_terminationList->free_node(address);
1209-
delete entry;
1210-
return true;
12111258
}
1212-
1213-
return false;
1214-
});
1259+
delete entry;
1260+
}
12151261
}
12161262

12171263
void TaskQueuePortImpl::ScheduleTermination(
@@ -1225,8 +1271,11 @@ void TaskQueuePortImpl::ScheduleTermination(
12251271

12261272
// This never fails because we preallocate the
12271273
// list node.
1228-
m_terminationList->push_back(term, term->node);
1229-
term->node = 0; // now owned by the list
1274+
{
1275+
std::lock_guard<std::mutex> lock(m_terminationLock);
1276+
m_terminationList->push_back(term, term->node);
1277+
term->node = 0; // now owned by the list
1278+
}
12301279

12311280
// The port should have already been marked as terminated, so now
12321281
// we can signal it to wake up. This should drain pending calls and
@@ -1236,6 +1285,12 @@ void TaskQueuePortImpl::ScheduleTermination(
12361285
NotifyItemQueued();
12371286
}
12381287

1288+
bool TaskQueuePortImpl::TerminationListEmpty()
1289+
{
1290+
std::lock_guard<std::mutex> lock(m_terminationLock);
1291+
return m_terminationList->empty();
1292+
}
1293+
12391294
#ifdef _WIN32
12401295
void CALLBACK TaskQueuePortImpl::WaitCallback(
12411296
_In_ PTP_CALLBACK_INSTANCE instance,

Source/Task/TaskQueueImpl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ class TaskQueuePortImpl: public Api<ApiId::TaskQueuePort, ITaskQueuePort>
262262
std::unique_ptr<LocklessQueue<QueueEntry>> m_pendingList;
263263
std::unique_ptr<LocklessQueue<TerminationEntry*>> m_terminationList;
264264
std::unique_ptr<LocklessQueue<TerminationEntry*>> m_pendingTerminationList;
265+
std::mutex m_terminationLock;
265266
OS::WaitTimer m_timer;
266267
OS::ThreadPool m_threadPool;
267268
std::atomic<uint64_t> m_timerDue = { UINT64_MAX };
@@ -312,6 +313,7 @@ class TaskQueuePortImpl: public Api<ApiId::TaskQueuePort, ITaskQueuePort>
312313

313314
void SignalTerminations();
314315
void ScheduleTermination(_In_ TerminationEntry* term);
316+
bool TerminationListEmpty();
315317

316318
void SignalQueue();
317319
void NotifyItemQueued();

0 commit comments

Comments
 (0)