From e5d8513a08f255b64605d5c9aa0da28c4e275239 Mon Sep 17 00:00:00 2001 From: Jin Li <59594262+liji-nv@users.noreply.github.com> Date: Wed, 8 Apr 2026 22:18:24 -0700 Subject: [PATCH] [None][fix] Batch addSequence with pre-claim to fix host offloading MNT overflow When host offloading is enabled, onboarding a host block to GPU during addSequence can trigger eviction of other reusable host blocks from the radix tree. This causes actual KV cache reuse to be less than the scheduler estimated, leading to max_num_tokens (MNT) overflow assertions. Add a new addSequenceBatch API that processes all first-chunk context requests in two phases: - Phase 1: Walk the radix tree and claimBlock() for all matching blocks across all requests. No onboarding, no allocation. This protects reusable blocks from eviction. - Phase 2: Onboard host blocks and allocate non-matching blocks. Since all reusable blocks are already claimed, evictions during onboarding cannot touch them. On the Python side, replace the TOCTOU-prone revalidation loop (count_reusable_blocks + budget check) with a single batch call. Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com> --- .../batch_manager/kvCacheManager.h | 79 ++++ .../batch_manager/kvCacheManager.cpp | 425 ++++++++++++++++++ .../nanobind/batch_manager/kvCacheManager.cpp | 30 +- .../_torch/pyexecutor/resource_manager.py | 55 ++- 4 files changed, 576 insertions(+), 13 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 06a864511e0..799717cdb07 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -735,6 +735,56 @@ class WindowBlockManager //! \brief Assign blocks for new sequence. Does not try to reuse blocks. void addSequence(GenerationRequest& sequence, SizeType32 numContextBlocks, bool isShareLastContextBlock); + //! 
\brief Per-request block allocation statistics from batch addSequence. + struct BatchSeqStats + { + SizeType32 prepopulatedLen{0}; + SizeType32 allocTotalDelta{0}; + SizeType32 allocNewDelta{0}; + SizeType32 reusedDelta{0}; + SizeType32 missedDelta{0}; + }; + + //! \brief Result of Phase 1 (claim-only) of batch addSequence. + //! \details Holds matched blocks and prepared data so Phase 2 can proceed without + //! re-traversing the radix tree. + struct ClaimResult + { + struct ClaimedBlock + { + BlockPtr block; + SizeType32 numMatchedTokens; //!< tokens matched in this block + bool isPartialMatch; + bool needsCopy; //!< partial match on block with refs or non-leaf (needs getFreeBlock + copy in Phase 2) + bool isPlaceholder; //!< placeholder block (linear attention recurrent states) + }; + + std::vector<ClaimedBlock> claimedBlocks; + BlockPtr claimedCopySource; //!< unreferenced non-leaf partial-match source claimed to protect from eviction + SizeType32 totalMatchedTokens{0}; + SizeType32 latestMatchingNonPlaceholderBlockIdx{-1}; + SizeType32 numSharedContextBlocks{0}; + SizeType32 numContextBlocks{0}; + bool shareLastContextBlockAmongBeams{true}; + std::vector<BlockKey> blockKeys; + std::vector<executor::RetentionPriorityAndDuration> perBlockRetentions; + executor::KvCacheTransferMode mode{executor::KvCacheTransferMode::DRAM}; + std::string directory; + }; + + //! \brief Batch add sequences with two-phase claim-then-onboard under a single lock. + //! \details Phase 1 claims all matching blocks across all requests (protecting from eviction). + //! Phase 2 onboards host blocks and allocates non-matching blocks. + //! The mCachedBlocksRootMutex is held for the entire operation. + //! \param sequences Per-request GenerationRequest references (parallel with other vectors). + //! \param inputLengths Per-request effective input length. + //! \param numContextBlocksVec Per-request number of context blocks. + //! \param llmRequests Per-request LlmRequest references. + //! \return Per-request BatchSeqStats (prepopulatedPromptLen plus allocation deltas). 
+ [[nodiscard]] std::vector<BatchSeqStats> addSequenceBatch(std::vector<GenerationRequest*> const& sequences, + std::vector<SizeType32> const& inputLengths, std::vector<SizeType32> const& numContextBlocksVec, + std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests); + //! \brief Allocate new block for each beam of the sequence. //! \details Might free cached blocks if no free blocks are available. void allocateBlock(GenerationRequest& sequence, bool shareAmongBeams); @@ -1048,6 +1098,16 @@ class WindowBlockManager bool shareLastContextBlockAmongBeams, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = ""); + //! \brief Phase 1 (no internal locking): Walk radix tree and claim matching blocks. + //! \details Caller must hold mCachedBlocksRootMutex. + [[nodiscard]] ClaimResult claimMatchingBlocks( + GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest); + + //! \brief Phase 2 (no internal locking): Onboard claimed host blocks and allocate non-matching blocks. + //! \details Caller must hold mCachedBlocksRootMutex. + [[nodiscard]] SizeType32 onboardAndAllocateBlocks( + GenerationRequest& sequence, LlmRequest& llmRequest, ClaimResult& claimResult); + + //! \brief Free block and all it's descendants. This makes block a claimed leaf block. void freeChildren(BlockPtr const& block); @@ -1242,6 +1302,12 @@ class BlockManager void addSequence( GenerationRequest& sequence, SizeType32 numContextBlocks, SizeType32 windowSize, bool isShareLastContextBlock); + //! \brief Batch add sequences forwarding to WindowBlockManager::addSequenceBatch. + [[nodiscard]] std::vector<WindowBlockManager::BatchSeqStats> addSequenceBatch( + std::vector<GenerationRequest*> const& sequences, std::vector<SizeType32> const& inputLengths, + std::vector<SizeType32> const& numContextBlocksVec, + std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests, SizeType32 windowSize); + void allocateBlock(GenerationRequest& sequence, SizeType32 windowSize); //! 
\brief According to request's current position, copy data from the last full block to the next block (ignoring @@ -1732,6 +1798,15 @@ class BaseKVCacheManager OptionalRef llmRequest = std::nullopt) = 0; + //! \brief Batch add sequences with two-phase claim-then-onboard to prevent host offloading eviction. + //! \details Phase 1 claims all matching blocks across all requests (protecting them from eviction). + //! Phase 2 onboards host blocks and allocates non-matching blocks. + //! Requires block reuse enabled and single attention window. + virtual void addSequenceBatch( + std::vector> const& requestInfos, + std::vector> const& llmRequests) + = 0; + [[nodiscard]] virtual std::optional removeSequence(LlmRequest::RequestIdType requestId, OptionalRef llmRequest = std::nullopt, bool pinOnRelease = false) = 0; @@ -2102,6 +2177,10 @@ class KVCacheManager : public BaseKVCacheManager void addSequence(LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, OptionalRef llmRequest = std::nullopt) override; + void addSequenceBatch( + std::vector> const& requestInfos, + std::vector> const& llmRequests) override; + [[nodiscard]] std::optional removeSequence(LlmRequest::RequestIdType requestId, OptionalRef llmRequest = std::nullopt, bool pinOnRelease = false) override; diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index a10b42fbd86..42ea58ca9fb 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -1391,6 +1391,345 @@ SizeType32 WindowBlockManager::countReusableBlocks( return reusableBlocks; } +WindowBlockManager::ClaimResult WindowBlockManager::claimMatchingBlocks( + GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest) +{ + // NOTE: Caller must hold mCachedBlocksRootMutex. 
+ TLLM_CHECK_WITH_INFO(!(isRecurrentState()) || inputLength == llmRequest.getPromptLen(), + "Recurrent state does not support CP or truncation yet."); + + ClaimResult result; + result.numContextBlocks = numContextBlocks; + + auto const requestId = sequence.getRequestId(); + auto const [seqIt, emplaceDone] = mAllocatedBlocksPerSeq.emplace(requestId, std::vector{}); + TLLM_CHECK(emplaceDone); + + // Prepare block keys — same logic as WindowBlockManager::addSequence lines 1437-1465 + auto constexpr beamIdx = 0; + auto const& uniqueTokens = (mCacheType == CacheType::kSELF || mCacheType == CacheType::kSELFKONLY) + ? llmRequest.getUniqueTokens(beamIdx) + : *(llmRequest.getEncoderUniqueTokens().value()); + + auto blockedUniqueTokens = chopVectorIntoBlocks(uniqueTokens, inputLength - 1, mTokensPerBlock, true); + if (inputLength % mTokensPerBlock == 1) + { + blockedUniqueTokens.emplace_back(); + } + + result.blockKeys = buildBlockKeys(blockedUniqueTokens, llmRequest); + + auto config = llmRequest.getKvCacheRetentionConfig(); + result.perBlockRetentions = config.value_or(executor::KvCacheRetentionConfig()) + .getPerBlockRetentionPriorityDuration(mTokensPerBlock, inputLength); + result.mode = config.value_or(executor::KvCacheRetentionConfig()).getTransferMode(); + result.directory = config.value_or(executor::KvCacheRetentionConfig()).getDirectory(); + + if (result.mode != executor::KvCacheTransferMode::DRAM && result.directory.empty()) + { + TLLM_LOG_WARNING( + "Transfer mode %d specified without directory, falling back to DRAM mode", static_cast(result.mode)); + result.mode = executor::KvCacheTransferMode::DRAM; + } + + TLLM_CHECK(result.perBlockRetentions.size() == static_cast(numContextBlocks)); + + // Phase 1: Walk radix tree, claim matching blocks — no onboard, no getFreeBlock + // NOTE: Caller must hold mCachedBlocksRootMutex. 
+ + // Compute shareLastContextBlockAmongBeams — same logic as WindowBlockManager::addSequence + result.shareLastContextBlockAmongBeams = sequence.getBeamWidth() == 1; + if (isRecurrentState()) + { + result.shareLastContextBlockAmongBeams |= inputLength % mTokensPerBlock == 0; + } + + result.numSharedContextBlocks = result.shareLastContextBlockAmongBeams ? numContextBlocks : numContextBlocks - 1; + auto searchRoot = mCachedBlocksRoot; + auto blockItr = result.blockKeys.begin(); + + for (int bi = 0; bi < result.numSharedContextBlocks; ++bi) + { + auto [partialMatch, numMatched, matchingBlock] = (searchRoot != nullptr && blockItr != result.blockKeys.end()) + ? searchRoot->findMatchingBlock(*blockItr, mEnablePartialReuse, mCopyOnPartialReuse) + : std::make_tuple(false, 0, nullptr); + if (isRecurrentState()) + { + TLLM_CHECK(partialMatch == false); + } + + if (matchingBlock != nullptr + && result.totalMatchedTokens + numMatched <= sequence.getCurrentPrepopulatedPromptLen()) + { + ClaimResult::ClaimedBlock claimed; + claimed.block = matchingBlock; + claimed.numMatchedTokens + = numMatched > 0 ? 
numMatched : static_cast(blockItr->uniqueTokens.size()); + claimed.isPartialMatch = partialMatch; + claimed.needsCopy = false; + claimed.isPlaceholder = matchingBlock->isPlaceholder(); + + result.totalMatchedTokens += claimed.numMatchedTokens; + if (!claimed.isPlaceholder) + { + result.latestMatchingNonPlaceholderBlockIdx = bi; + } + + // Priority update event + if (result.perBlockRetentions[bi].retentionPriority.has_value() + && matchingBlock->getPriority() != result.perBlockRetentions[bi].retentionPriority && mEventManager) + { + mEventManager->enqueueUpdatedEvent(tle::KVCacheUpdatedData(matchingBlock->getHash()) + .priorityUpdated(matchingBlock->getPriority(), + *result.perBlockRetentions[bi].retentionPriority), + mWindowSize); + } + + if (partialMatch) + { + if (matchingBlock->hasRefs() || !matchingBlock->isLeaf()) + { + // Block in use or has children — needs copy in Phase 2. + claimed.needsCopy = true; + if (!matchingBlock->hasRefs()) + { + // Unreferenced non-leaf: could be in the free queue and evictable. + // Claim to protect during Phase 2 copies; deferred release at batch end. 
+ mEvictionPolicy->claimBlock(matchingBlock, result.perBlockRetentions[bi].retentionPriority, + result.perBlockRetentions[bi].durationMs); + result.claimedCopySource = matchingBlock; + } + } + else + { + // Leaf with no refs — claim it now (freeLeafBlock + claimBlock, no eviction) + freeLeafBlock(matchingBlock); + mEvictionPolicy->claimBlock(matchingBlock, result.perBlockRetentions[bi].retentionPriority, + result.perBlockRetentions[bi].durationMs); + } + searchRoot = nullptr; // no matching for following blocks + } + else + { + // Full match — claim block (removes from free queue, protecting from eviction) + searchRoot = matchingBlock; + mEvictionPolicy->claimBlock(matchingBlock, result.perBlockRetentions[bi].retentionPriority, + result.perBlockRetentions[bi].durationMs); + } + + result.claimedBlocks.push_back(std::move(claimed)); + ++blockItr; + } + else + { + // No match — stop matching, remaining blocks handled in Phase 2 + break; + } + } + + TLLM_LOG_DEBUG("%s::claimMatchingBlocks for request %lu - Claimed %zu blocks, %d matched tokens", + mLogPrefix.c_str(), sequence.getRequestId(), result.claimedBlocks.size(), result.totalMatchedTokens); + + return result; +} + +SizeType32 WindowBlockManager::onboardAndAllocateBlocks( + GenerationRequest& sequence, LlmRequest& llmRequest, ClaimResult& claimResult) +{ + // NOTE: Caller must hold mCachedBlocksRootMutex. 
+ std::set reusedBlockIds; + auto blockItr = claimResult.blockKeys.begin(); + SizeType32 bi = 0; + + // Process claimed (matched) blocks: onboard + addBlockToAllBeams + for (auto& claimed : claimResult.claimedBlocks) + { + KVCacheBlock::IdType matchingBlockId = claimed.block->getBlockId(); + + if (claimed.isPartialMatch && claimed.needsCopy) + { + // Partial match needing copy: allocate new block, copy from source, use new block + auto newBlock = getFreeBlock(sequence, claimed.block->getPriority(), claimed.block->getDurationMs(), + claimResult.mode, claimResult.directory); + mTransferManager->onboard( + claimed.block, newBlock, mPools, claimed.numMatchedTokens, claimResult.mode, claimResult.directory); + claimed.block = newBlock; + if (blockItr != claimResult.blockKeys.end()) + { + claimed.block->setBlockKey( + *blockItr, blockItr->uniqueTokens.size() == static_cast(mTokensPerBlock)); + } + claimed.block->setHash(); + TLLM_LOG_DEBUG("%s::onboardAndAllocateBlocks for request %lu - Copied partially filled block %d", + mLogPrefix.c_str(), sequence.getRequestId(), matchingBlockId); + } + else if (claimed.isPartialMatch) + { + // Partial leaf match — already claimed in Phase 1 + TLLM_LOG_DEBUG("%s::onboardAndAllocateBlocks for request %lu - Reused partially filled block %d", + mLogPrefix.c_str(), sequence.getRequestId(), matchingBlockId); + } + else + { + // Full match — already claimed in Phase 1 + TLLM_LOG_DEBUG("%s::onboardAndAllocateBlocks for request %lu - Matched full block %d", mLogPrefix.c_str(), + sequence.getRequestId(), matchingBlockId); + } + + onboardBlock(sequence, claimed.block, claimResult.mode, claimResult.directory); + addBlockToAllBeams(claimed.block, sequence); + if (!claimed.isPlaceholder) + { + ++mReusedBlocks; + if (!reusedBlockIds.count(matchingBlockId)) + { + reusedBlockIds.insert(matchingBlockId); + ++mReusedUniqueBlocks; + } + } + ++blockItr; + ++bi; + } + + // Allocate non-matching shared context blocks + for (; bi < 
claimResult.numSharedContextBlocks; ++bi) + { + bool shouldAllocate = true; + if (isRecurrentState()) + { + shouldAllocate = mLinearAttentionMetadata->shouldAllocateRecurrentStates( + /*currentBlockEndTokenIdx=*/(bi + 1) * mTokensPerBlock, llmRequest.getPromptLen(), mTokensPerBlock); + TLLM_LOG_DEBUG( + "%s::onboardAndAllocateBlocks - Recurrent state block %d. shouldAllocate=%d for sequence %lu", + mLogPrefix.c_str(), bi, shouldAllocate, sequence.getRequestId()); + } + + auto freeBlock = getFreeBlock(sequence, + claimResult.perBlockRetentions[bi].retentionPriority.value_or( + executor::KvCacheRetentionConfig::kDefaultRetentionPriority), + claimResult.perBlockRetentions[bi].durationMs, claimResult.mode, claimResult.directory, + /*wantPlaceholder=*/!shouldAllocate); + addBlockToAllBeams(freeBlock, sequence); + TLLM_LOG_DEBUG("%s::onboardAndAllocateBlocks - No match, allocated new block %d for sequence %lu", + mLogPrefix.c_str(), freeBlock->getBlockId(), sequence.getRequestId()); + if (blockItr != claimResult.blockKeys.end()) + { + freeBlock->setBlockKey(*blockItr, blockItr->uniqueTokens.size() == static_cast(mTokensPerBlock)); + ++blockItr; + } + freeBlock->setHash(); + ++mMissedBlocks; + } + + // Allocate non-shared last blocks (beam > 1) + auto const beamWidth = sequence.getBeamWidth(); + for (int nbi = claimResult.numSharedContextBlocks; nbi < claimResult.numContextBlocks; ++nbi) + { + for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx) + { + auto freeBlock = getFreeBlock(sequence, + claimResult.perBlockRetentions[nbi].retentionPriority.value_or( + executor::KvCacheRetentionConfig::kDefaultRetentionPriority), + claimResult.perBlockRetentions[nbi].durationMs, claimResult.mode, claimResult.directory); + addBlockToBeam(freeBlock, sequence, beamIdx); + if (blockItr != claimResult.blockKeys.end()) + { + freeBlock->setBlockKey( + *blockItr, blockItr->uniqueTokens.size() == static_cast(mTokensPerBlock)); + ++blockItr; + } + freeBlock->setHash(); + 
TLLM_LOG_DEBUG("%s::onboardAndAllocateBlocks - Beam %d. Allocated non-shared block %d for bi %d", + mLogPrefix.c_str(), beamIdx, freeBlock->getBlockId(), nbi); + } + ++mMissedBlocks; + if (blockItr != claimResult.blockKeys.end()) + { + ++blockItr; + } + } + + // Finalize matched token count (purge trailing placeholders for recurrent states) + auto numMatchedTokens = claimResult.totalMatchedTokens; + if (isRecurrentState()) + { + numMatchedTokens = (claimResult.latestMatchingNonPlaceholderBlockIdx + 1) * mTokensPerBlock; + } + sequence.setCurrentPrepopulatedPromptLen(numMatchedTokens); + + // Update stats and return prepopulated length + mReusedTokens += static_cast(numMatchedTokens); + auto constexpr beamIdx = 0; + auto const& uniqueTokens = (mCacheType == CacheType::kSELF || mCacheType == CacheType::kSELFKONLY) + ? llmRequest.getUniqueTokens(beamIdx) + : *(llmRequest.getEncoderUniqueTokens().value()); + mTotalInputTokens += static_cast(uniqueTokens.size()); + + SizeType32 numConnectorMatchedTokens = 0; + if (mKvCacheConnectorManager && !llmRequest.isDummyRequest()) + { + numConnectorMatchedTokens + = mKvCacheConnectorManager->getNumNewMatchedTokens(llmRequest, sequence.getCurrentPrepopulatedPromptLen()); + } + + auto const totalPrepopulatedLen = sequence.getCurrentPrepopulatedPromptLen() + numConnectorMatchedTokens; + TLLM_LOG_DEBUG("%s::onboardAndAllocateBlocks: Request %lu, prepopulatedPromptLen %d, numConnectorMatchedTokens %d", + mLogPrefix.c_str(), llmRequest.mRequestId, sequence.getCurrentPrepopulatedPromptLen(), + numConnectorMatchedTokens); + return totalPrepopulatedLen; +} + +std::vector WindowBlockManager::addSequenceBatch( + std::vector const& sequences, std::vector const& inputLengths, + std::vector const& numContextBlocksVec, + std::vector> const& llmRequests) +{ + auto const n = sequences.size(); + std::vector claimResults(n); + std::vector results(n); + + // Hold the lock for the entire two-phase operation. 
+ std::lock_guard lock(mCachedBlocksRootMutex); + + // Phase 1: Claim all matching blocks across all requests. + for (size_t i = 0; i < n; ++i) + { + claimResults[i] + = claimMatchingBlocks(*sequences[i], inputLengths[i], numContextBlocksVec[i], llmRequests[i].get()); + } + + // Phase 2: Onboard + allocate for each request, snapshotting stats between requests. + for (size_t i = 0; i < n; ++i) + { + SizeType32 const preTotalBlocks = mAllocTotalBlocks; + SizeType32 const preNewBlocks = mAllocNewBlocks; + SizeType32 const preReused = mReusedBlocks; + SizeType32 const preMissed = mMissedBlocks; + + results[i].prepopulatedLen = onboardAndAllocateBlocks(*sequences[i], llmRequests[i].get(), claimResults[i]); + + results[i].allocTotalDelta = mAllocTotalBlocks - preTotalBlocks; + results[i].allocNewDelta = mAllocNewBlocks - preNewBlocks; + results[i].reusedDelta = mReusedBlocks - preReused; + results[i].missedDelta = mMissedBlocks - preMissed; + } + + // Deferred release: return claimed partial-match copy sources to the free queue + // if they still have no refs. Deduplicate by block ID to avoid double-release + // when multiple requests partial-matched the same source. 
std::set<KVCacheBlock::IdType> releasedSourceIds; + for (auto const& cr : claimResults) + { + if (cr.claimedCopySource && releasedSourceIds.insert(cr.claimedCopySource->getBlockId()).second + && !cr.claimedCopySource->hasRefs()) + { + mEvictionPolicy->releaseBlock(cr.claimedCopySource); + } + } + + return results; + } + bool WindowBlockManager::blockInRadixTree(BlockPtr const& block) { return !block->getUniqueTokens().empty() && block->getPrevBlock() != nullptr; @@ -1631,6 +1970,15 @@ SizeType32 BlockManager::addSequence(GenerationRequest& sequence, SizeType32 inp return mWindowBlockManagers.at(windowSize).addSequence(sequence, inputLength, numContextBlocks, llmRequest); } +std::vector<WindowBlockManager::BatchSeqStats> BlockManager::addSequenceBatch( + std::vector<GenerationRequest*> const& sequences, std::vector<SizeType32> const& inputLengths, + std::vector<SizeType32> const& numContextBlocksVec, + std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests, SizeType32 windowSize) +{ + return mWindowBlockManagers.at(windowSize) + .addSequenceBatch(sequences, inputLengths, numContextBlocksVec, llmRequests); +} + // There are two versions of WindowBlockManager::addSequence function. // This is called when block reuse is enabled. // Returns the total prepopulatedPromptLen (including connector matched tokens) for this window. 
@@ -3073,6 +3421,83 @@ void KVCacheManager::addSequence( } } +void KVCacheManager::addSequenceBatch( + std::vector<std::tuple<LlmRequest::RequestIdType, SizeType32, SizeType32>> const& requestInfos, + std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests) +{ + TLLM_CHECK(requestInfos.size() == llmRequests.size()); + TLLM_CHECK_WITH_INFO( + !mBlockManager.isVariableWindow(), "addSequenceBatch does not support variable window attention"); + if (requestInfos.empty()) + { + return; + } + TLLM_CHECK_WITH_INFO(mEnableBlockReuse, "addSequenceBatch requires block reuse to be enabled"); + + auto const& [firstWindowSize, firstMetadata] = *mBlockManager.getWindowSizesMetadata().begin(); + + auto const n = requestInfos.size(); + + // --- Setup: create sequences, hold them, compute effective input length --- + std::vector<GenerationRequest*> sequences(n); + std::vector<SizeType32> inputLengths(n); + std::vector<SizeType32> numContextBlocksVec(n); + + for (size_t i = 0; i < n; ++i) + { + auto const requestId = std::get<0>(requestInfos[i]); + auto const inputLength = std::get<1>(requestInfos[i]); + auto const beamWidth = std::get<2>(requestInfos[i]); + auto& llmRequest = llmRequests[i].get(); + + auto kvCacheRetentionConfig + = llmRequest.getKvCacheRetentionConfig().value_or(executor::KvCacheRetentionConfig()); + + auto const [seqIt, emplaceDone] = [&] + { + auto lck = std::scoped_lock(mSequencesMtx); + return mSequences.try_emplace(requestId, requestId, inputLength, beamWidth, + mBlockManager.getWindowSizesMetadata(), kvCacheRetentionConfig); + }(); + TLLM_CHECK(emplaceDone); + + sequences[i] = &seqIt->second; + + if (!mBlockManager.isSequenceHeld(requestId)) + { + mBlockManager.holdSequence(requestId); + } + + auto const maxTokenNum = firstMetadata.maxTokenNum; + auto const temporaryAttentionWindow = firstMetadata.temporaryAttentionWindow; + inputLengths[i] = std::min(inputLength, maxTokenNum + temporaryAttentionWindow); + numContextBlocksVec[i] = tc::ceilDiv(inputLengths[i], getTokensPerBlock()); + } + + // --- Two-phase claim-then-onboard under a single lock --- + auto const batchResults + = 
mBlockManager.addSequenceBatch(sequences, inputLengths, numContextBlocksVec, llmRequests, firstWindowSize); + + // --- Finalize: update offsets, set prepopulated length, update per-request stats --- + for (size_t i = 0; i < n; ++i) + { + auto& llmRequest = llmRequests[i].get(); + auto const& stats = batchResults[i]; + + mBlockManager.updateSequenceCacheBlockOffsets(*sequences[i], firstWindowSize); + + TLLM_LOG_DEBUG("KVCacheManager::addSequenceBatch: Setting prepopulatedPromptLen to %d for request %lu", + stats.prepopulatedLen, llmRequest.mRequestId); + llmRequest.setPrepopulatedPromptLen(stats.prepopulatedLen, getTokensPerBlock()); + llmRequest.setEstimatedReusableTokens(0); + + llmRequest.updateAllocTotalBlocksPerRequest(stats.allocTotalDelta); + llmRequest.updateAllocNewBlocksPerRequest(stats.allocNewDelta); + llmRequest.updateReusedBlocksPerRequest(stats.reusedDelta); + llmRequest.updateMissedBlocksPerRequest(stats.missedDelta); + } +} + void KVCacheManager::storeContextBlocks(LlmRequest const& llmRequest) { auto const requestId = llmRequest.mRequestId; diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp index fe7687aee77..75bd34cbdf1 100755 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -68,7 +68,7 @@ std::optional from_torch(std::optiona class PyKvCacheManager : public tbk::BaseKVCacheManager { public: - NB_TRAMPOLINE(tbk::BaseKVCacheManager, 36); + NB_TRAMPOLINE(tbk::BaseKVCacheManager, 37); // using BaseKVCacheManager::BaseKVCacheManager; // Inherit constructors void allocatePools(bool useUvm = false) override @@ -122,6 +122,13 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager NB_OVERRIDE_PURE(addSequence, requestId, inputLength, beamWidth, llmRequest); } + void addSequenceBatch( + std::vector> const& requestInfos, + std::vector> const& llmRequests) override + { + 
NB_OVERRIDE_PURE(addSequenceBatch, requestInfos, llmRequests); + } + std::optional removeSequence(tb::LlmRequest::RequestIdType requestId, tensorrt_llm::common::OptionalRef llmRequest = std::nullopt, bool pinOnRelease = false) override @@ -393,6 +400,27 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) .def("add_token", &BaseKVCacheManager::addToken, nb::call_guard()) .def("get_token_count", &BaseKVCacheManager::getTokenCount, nb::arg("request_id")) .def("add_sequence", &BaseKVCacheManager::addSequence, nb::call_guard()) + .def( + "add_sequence_batch", + [](tbk::BaseKVCacheManager& self, nb::list requestInfosList, nb::list llmRequestsList) + { + // Marshal Python inputs while GIL is held. + std::vector> requestInfos; + std::vector> llmRequests; + requestInfos.reserve(nb::len(requestInfosList)); + llmRequests.reserve(nb::len(llmRequestsList)); + for (size_t i = 0; i < nb::len(requestInfosList); ++i) + { + auto info = nb::cast(requestInfosList[i]); + requestInfos.emplace_back(nb::cast(info[0]), + nb::cast(info[1]), nb::cast(info[2])); + llmRequests.push_back(std::ref(nb::cast(llmRequestsList[i]))); + } + // Release GIL only for the C++ call. 
+ nb::gil_scoped_release release; + self.addSequenceBatch(requestInfos, llmRequests); + }, + nb::arg("request_infos"), nb::arg("llm_requests")) .def("remove_sequence", &BaseKVCacheManager::removeSequence, nb::call_guard()) .def("pin_blocks", &BaseKVCacheManager::pinBlocks, nb::call_guard()) .def("scheduling_remove_sequence", &BaseKVCacheManager::schedulingRemoveSequence, diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 971ff8c5402..aebd8b96f07 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -619,6 +619,14 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): # wait for all pending work to finish before launching offload/onboarding/partial copy self.impl.sync_transfer_manager_with_buffer_manager() + # Collect first-chunk requests eligible for batch add_sequence. + # When block reuse is enabled, addSequenceBatch uses a two-phase + # claim-then-onboard strategy that prevents host offloading from + # evicting reusable blocks in the radix tree. 
+ batch_request_infos = [] + batch_llm_requests = [] + batch_ctx_requests = [] + # allocate KV Cache for req in scheduled_batch.context_requests: req_beam_width = req.sampling_config.beam_width @@ -635,18 +643,41 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): else: if req.is_first_context_chunk and self._kv_connector_should_add_sequence( req): - self.impl.add_sequence(req.py_request_id, - req.prompt_len, req_beam_width, - req) - for _ in range(self.num_extra_kv_tokens): - self.impl.add_token(req.py_request_id) - for _ in range(get_draft_token_length(req)): - self.impl.add_token(req.py_request_id) - - if self.kv_connector_manager is not None: - block_ids = self.get_cache_indices(req) - self.kv_connector_manager.update_state_after_alloc( - req, block_ids) + if self.enable_block_reuse and not self.is_vswa: + # Batch path: two-phase claim-then-onboard + # (not supported for VSWA which needs multi-window addSequence) + batch_request_infos.append( + (req.py_request_id, req.prompt_len, + req_beam_width)) + batch_llm_requests.append(req) + batch_ctx_requests.append(req) + else: + self.impl.add_sequence(req.py_request_id, + req.prompt_len, + req_beam_width, req) + for _ in range(self.num_extra_kv_tokens): + self.impl.add_token(req.py_request_id) + for _ in range(get_draft_token_length(req)): + self.impl.add_token(req.py_request_id) + + if self.kv_connector_manager is not None: + block_ids = self.get_cache_indices(req) + self.kv_connector_manager.update_state_after_alloc( + req, block_ids) + + if batch_request_infos: + self.impl.add_sequence_batch(batch_request_infos, + batch_llm_requests) + for req in batch_ctx_requests: + for _ in range(self.num_extra_kv_tokens): + self.impl.add_token(req.py_request_id) + for _ in range(get_draft_token_length(req)): + self.impl.add_token(req.py_request_id) + + if self.kv_connector_manager is not None: + block_ids = self.get_cache_indices(req) + self.kv_connector_manager.update_state_after_alloc( + req, block_ids) # A 
request may change from `context_requests_chunking` to `context_requests_last_chunk` in `add_sequence` due to KV cache reuse, so we rebuild the context request lists here. scheduled_batch.reset_context_requests()