Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 171 additions & 30 deletions cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -339,6 +339,10 @@ void CacheTransceiver::respondAndSendAsync(LlmRequest* llmRequest)
}
setContextState(llmRequest);
auto future = mCacheSender->sendAsync(*llmRequest);
TLLM_LOG_DEBUG("respondAndSendAsync: adding request %ld to mSenderFutures (ptr=%p, transferStart=%ld, size=%zu)",
llmRequest->mRequestId, static_cast<void*>(llmRequest),
static_cast<long>(llmRequest->getKvCacheTransferStart().time_since_epoch().count()),
mSenderFutures.size() + 1);
mSenderFutures.emplace_back(llmRequest, std::move(future));
}

Expand Down Expand Up @@ -485,10 +489,35 @@ RequestStatuses CacheTransceiver::checkContextTransferStatus(
{
bool blockAll = !atLeastRequestNum.has_value();
std::optional<int> senderFutureTimeoutMs = std::nullopt;
// If blockAll is true, we want to block and not use a timeout
if (!blockAll && mCacheTransceiverConfig.has_value())
std::optional<int> kvTransferTimeoutMs = std::nullopt;
// Always use a bounded timeout to prevent unbounded blocking.
// The caller (scheduler) loops, so timed-out transfers retry on next iteration.
if (mCacheTransceiverConfig.has_value())
{
senderFutureTimeoutMs = mCacheTransceiverConfig->getKvTransferSenderFutureTimeoutMs();
kvTransferTimeoutMs = mCacheTransceiverConfig->getKvTransferTimeoutMs();
}
{
senderFutureTimeoutMs = mCacheTransceiverConfig->getKvTransferSenderFutureTimeoutMs();
kvTransferTimeoutMs = mCacheTransceiverConfig->getKvTransferTimeoutMs();
}

// Log mSenderFutures state for diagnosing dangling pointer issues.
// Each entry's pointer address and request ID are logged so we can detect
// when a pointer's underlying memory is freed (reqId changes to 0).
if (!mSenderFutures.empty())
{
TLLM_LOG_DEBUG("checkContextTransferStatus: mSenderFutures.size()=%zu, blockAll=%d, "
"kvTransferTimeoutMs=%d",
mSenderFutures.size(), blockAll ? 1 : 0,
kvTransferTimeoutMs.value_or(-1));
for (size_t i = 0; i < mSenderFutures.size(); ++i)
{
auto& [req, fut] = mSenderFutures[i];
auto startTs = req->getKvCacheTransferStart().time_since_epoch().count();
TLLM_LOG_DEBUG(" [%zu] ptr=%p reqId=%ld startTs=%ld",
i, static_cast<void const*>(req), req->mRequestId, static_cast<long>(startTs));
}
}

auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupTPInDPComm : mGroupTensorParaComm;
Expand Down Expand Up @@ -551,8 +580,9 @@ RequestStatuses CacheTransceiver::checkContextTransferStatus(
try
{
// Wait for up to a specified timeout
auto status = future.wait_for(std::chrono::milliseconds(senderFutureTimeoutMs.value_or(0)));
if (status == std::future_status::ready || !senderFutureTimeoutMs.has_value())
auto const timeoutMs = senderFutureTimeoutMs.value_or(1000);
auto status = future.wait_for(std::chrono::milliseconds(timeoutMs));
if (status == std::future_status::ready)
{
future.get();
requestsStatus.completedRequestIds.insert(request->mRequestId);
Expand All @@ -564,8 +594,59 @@ RequestStatuses CacheTransceiver::checkContextTransferStatus(
}
else if (status == std::future_status::timeout)
{
// Check if total elapsed time exceeds kv_transfer_timeout_ms.
// Without this, stuck transfers retry the per-iteration timeout forever,
// holding KV blocks indefinitely and exhausting the cache pool.
if (kvTransferTimeoutMs.has_value())
{
auto transferStart = request->getKvCacheTransferStart();
// Guard: if transfer start was never set (TimePoint epoch),
// the request pointer may be stale or the start time was not recorded.
// Treat as timed out immediately to avoid infinite retry.
bool startTimeValid = transferStart.time_since_epoch().count() > 0;
bool shouldTimeout = !startTimeValid;
long elapsedMs = 0;
if (startTimeValid)
{
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
LlmRequest::getSteadyClockNow() - transferStart);
elapsedMs = static_cast<long>(elapsed.count());
shouldTimeout = elapsedMs > kvTransferTimeoutMs.value();
}
if (shouldTimeout)
{
if (startTimeValid)
{
TLLM_LOG_ERROR(
"Context KV cache transfer for request %ld exceeded total timeout: "
"elapsed %ld ms > limit %d ms. Marking as error.",
request->mRequestId, elapsedMs, kvTransferTimeoutMs.value());
try
{
mCacheSender->cancelRequest(*request);
}
catch (...)
{
}
request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
requestsStatus.errorRequestIds.insert(request->mRequestId);
}
else
{
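// Start time is epoch — the LlmRequest* is likely a dangling pointer
// (request was freed while still in mSenderFutures). Do NOT dereference
// request beyond this point; just remove the stale entry.
// NOTE(review): the start time tested here was itself read through this
// pointer (request->getKvCacheTransferStart() above), so this check is a
// best-effort heuristic, not a guarantee that no stale memory was touched.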
TLLM_LOG_WARNING(
"Removing stale entry from mSenderFutures: transfer start time is "
"uninitialized (request pointer %p may be dangling).",
static_cast<void const*>(request));
}
it = mSenderFutures.erase(it);
continue;
}
}
TLLM_LOG_WARNING("Timed out waiting for context KV cache transfer after %d milliseconds.",
senderFutureTimeoutMs.value());
timeoutMs);
++it;
}
else
Expand All @@ -580,10 +661,26 @@ RequestStatuses CacheTransceiver::checkContextTransferStatus(
}
catch (std::exception const& e)
{
TLLM_LOG_ERROR(
"Error occurred during context transfer for request %ld: %s", request->mRequestId, e.what());
request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
requestsStatus.errorRequestIds.insert(request->mRequestId);
// Guard: the request pointer may be stale if the sender thread crashed
// and the request was freed concurrently. Check transfer start time as
// a heuristic — epoch (0) indicates likely dangling pointer.
auto transferStart = request->getKvCacheTransferStart();
bool likelyValid = transferStart.time_since_epoch().count() > 0;
if (likelyValid)
{
TLLM_LOG_ERROR(
"Error occurred during context transfer for request %ld: %s",
request->mRequestId, e.what());
request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
requestsStatus.errorRequestIds.insert(request->mRequestId);
}
else
{
TLLM_LOG_WARNING(
"Error during context transfer with likely stale request pointer %p: %s. "
"Removing entry without setting state.",
static_cast<void const*>(request), e.what());
}
it = mSenderFutures.erase(it);
}
}
Expand All @@ -593,12 +690,29 @@ RequestStatuses CacheTransceiver::checkContextTransferStatus(
}
}

if (!requestsStatus.completedRequestIds.empty() || !requestsStatus.errorRequestIds.empty())
{
TLLM_LOG_DEBUG("checkContextTransferStatus done: completed=%zu, errors=%zu, "
"mSenderFutures.size()=%zu",
requestsStatus.completedRequestIds.size(),
requestsStatus.errorRequestIds.size(),
mSenderFutures.size());
}

return requestsStatus;
}

void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastRequestNum)
{
bool blockAll = !atLeastRequestNum.has_value();
std::optional<int> receiverFutureTimeoutMs = std::nullopt;
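// Always use a bounded timeout to prevent unbounded blocking.
// The caller (scheduler) loops, so timed-out transfers retry on next iteration.
// NOTE(review): the receiver-side timeout below is read from the
// sender-future timeout setting (getKvTransferSenderFutureTimeoutMs);
// confirm that sharing this setting is intended.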
if (mCacheTransceiverConfig.has_value())
{
receiverFutureTimeoutMs = mCacheTransceiverConfig->getKvTransferSenderFutureTimeoutMs();
}

std::vector<LlmRequest::RequestIdType> genTransferReadyRequestIds;
for (auto&& [request, future] : mRequesterFutures)
{
Expand Down Expand Up @@ -709,41 +823,68 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
" checkGenTransferStatus toCompleteIdSet size: %zu, atLeastRequestNum: %d ", toCompleteIdSet.size(),
atLeastRequestNum.value_or(0));
}
auto const syncSize = (syncComm != nullptr) ? syncComm->getSize() : 1;
for (auto it = mRequesterFutures.begin(); it != mRequesterFutures.end();)
{
if (blockAll || toCompleteIdSet.find(it->first->mRequestId) != toCompleteIdSet.end())
{
try
{
it->second.get();
it->first->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE);

// Gather the kv cache transfer time from all workers and update to leader rank
if (!common::getEnvKVCacheTimeOutputPath().empty())
// Wait for up to a specified timeout
auto const timeoutMs = receiverFutureTimeoutMs.value_or(1000);
auto status = it->second.wait_for(std::chrono::milliseconds(timeoutMs));
if (status == std::future_status::ready)
{
it->second.get();
it->first->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE);

// Gather the kv cache transfer time from all workers and update to leader rank.
// Only call the timing collective when either all ranks block together (blockAll)
// or the request was confirmed ready on every rank in the initial poll, to avoid
// hanging in allgather when a peer timed out and skipped this request.
if (!common::getEnvKVCacheTimeOutputPath().empty())
{
auto const freqIt = frequencyMap.find(it->first->mRequestId);
if (blockAll || (freqIt != frequencyMap.end() && freqIt->second == syncSize))
{
updateKVCacheTransferBW(syncComm, it->first);
}
}
if (useMPI())
{
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
"**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***",
it->first->mRequestId, it->first->getContextPhaseParams().value().getReqId());
}
else
{
TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
"**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***",
it->first->mRequestId, it->first->getContextPhaseParams().value().getReqId());
}
it = mRequesterFutures.erase(it);
}
else if (status == std::future_status::timeout)
{
auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupDataComm : mGroupComm;
updateKVCacheTransferBW(syncComm, it->first);
TLLM_LOG_WARNING(
"Timed out waiting for generation KV cache transfer after %d milliseconds.", timeoutMs);
++it;
}
else
{
TLLM_LOG_ERROR("Future returned unexpected status for request %ld. Marking as error",
it->first->mRequestId);
it->first->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
it = mRequesterFutures.erase(it);
}
}
catch (std::exception const& e)
{
TLLM_LOG_ERROR(
"Error occurred during generation transfer for request %ld: %s", it->first->mRequestId, e.what());
it->first->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
it = mRequesterFutures.erase(it);
}
if (useMPI())
{
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
"**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***",
it->first->mRequestId, it->first->getContextPhaseParams().value().getReqId());
}
else
{
TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
"**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***",
it->first->mRequestId, it->first->getContextPhaseParams().value().getReqId());
}
it = mRequesterFutures.erase(it);
}
else
{
Expand Down
8 changes: 8 additions & 0 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1686,6 +1686,14 @@ std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> WindowBlockManager::sto
}
if (pinBlocks)
{
// Claim block from free queue before pinning so that
// unpinBlocksById can safely releaseBlock it back later.
// Without this, the block stays in the free queue while pinned,
// and the subsequent releaseBlock creates a duplicate entry.
if (!searchRoot->hasRefs())
{
mEvictionPolicy->claimBlock(searchRoot, searchRoot->getPriority(), searchRoot->getDurationMs());
}
searchRoot->incRefCount();
pinnedBlockIds.push_back(searchRoot->getBlockId());
}
Expand Down
23 changes: 15 additions & 8 deletions tensorrt_llm/_torch/pyexecutor/perf_metrics_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,21 @@ def compute_batch_gpu_times(self, requests):

# Compute once per batch, reuse for all requests
if batch_gpu_forward_time is None:
batch_gpu_forward_time = perf.gpu_forward_start_event.elapsed_time(
perf.gpu_forward_end_event
)
batch_gpu_sample_time = (
perf.gpu_forward_end_event.elapsed_time(perf.gpu_sample_end_event)
if perf.gpu_sample_end_event
else 0.0
)
try:
batch_gpu_forward_time = perf.gpu_forward_start_event.elapsed_time(
perf.gpu_forward_end_event
)
batch_gpu_sample_time = (
perf.gpu_forward_end_event.elapsed_time(perf.gpu_sample_end_event)
if perf.gpu_sample_end_event
else 0.0
)
except RuntimeError:
# CUDA event timing can fail if events were not recorded
# on the current stream. Skip metrics for this batch rather
# than crashing the executor thread.
batch_gpu_forward_time = 0.0
batch_gpu_sample_time = 0.0

target["gpu_forward_time"] = batch_gpu_forward_time
target["gpu_sample_time"] = batch_gpu_sample_time
Expand Down