diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
index 06a864511e0..799717cdb07 100644
--- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
+++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -735,6 +735,56 @@ class WindowBlockManager
     //! \brief Assign blocks for new sequence. Does not try to reuse blocks.
     void addSequence(GenerationRequest& sequence, SizeType32 numContextBlocks, bool isShareLastContextBlock);
 
+    //! \brief Per-request block allocation statistics from batch addSequence.
+    struct BatchSeqStats
+    {
+        SizeType32 prepopulatedLen{0};
+        SizeType32 allocTotalDelta{0};
+        SizeType32 allocNewDelta{0};
+        SizeType32 reusedDelta{0};
+        SizeType32 missedDelta{0};
+    };
+
+    //! \brief Result of Phase 1 (claim-only) of batch addSequence.
+    //! \details Holds matched blocks and prepared data so Phase 2 can proceed without
+    //!          re-traversing the radix tree.
+    struct ClaimResult
+    {
+        struct ClaimedBlock
+        {
+            BlockPtr block;
+            SizeType32 numMatchedTokens; //!< tokens matched in this block
+            bool isPartialMatch;
+            bool needsCopy;     //!< partial match on block with refs or non-leaf (needs getFreeBlock + copy in Phase 2)
+            bool isPlaceholder; //!< placeholder block (linear attention recurrent states)
+        };
+
+        std::vector<ClaimedBlock> claimedBlocks;
+        BlockPtr claimedCopySource; //!< unreferenced non-leaf partial-match source claimed to protect from eviction
+        SizeType32 totalMatchedTokens{0};
+        SizeType32 latestMatchingNonPlaceholderBlockIdx{-1};
+        SizeType32 numSharedContextBlocks{0};
+        SizeType32 numContextBlocks{0};
+        bool shareLastContextBlockAmongBeams{true};
+        std::vector<BlockKey> blockKeys;
+        std::vector<executor::RetentionPriorityAndDuration> perBlockRetentions;
+        executor::KvCacheTransferMode mode{executor::KvCacheTransferMode::DRAM};
+        std::string directory;
+    };
+
+    //! \brief Batch add sequences with two-phase claim-then-onboard under a single lock.
+    //! \details Phase 1 claims all matching blocks across all requests (protecting from eviction).
+    //!          Phase 2 onboards host blocks and allocates non-matching blocks.
+    //!          The mCachedBlocksRootMutex is held for the entire operation.
+    //! \param sequences Per-request GenerationRequest references (parallel with other vectors).
+    //! \param inputLengths Per-request effective input length.
+    //! \param numContextBlocksVec Per-request number of context blocks.
+    //! \param llmRequests Per-request LlmRequest references.
+    //! \return Per-request prepopulatedPromptLen.
+    [[nodiscard]] std::vector<BatchSeqStats> addSequenceBatch(std::vector<GenerationRequest*> const& sequences,
+        std::vector<SizeType32> const& inputLengths, std::vector<SizeType32> const& numContextBlocksVec,
+        std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests);
+
     //! \brief Allocate new block for each beam of the sequence.
     //! \details Might free cached blocks if no free blocks are available.
     void allocateBlock(GenerationRequest& sequence, bool shareAmongBeams);
@@ -1048,6 +1098,16 @@ class WindowBlockManager
         bool shareLastContextBlockAmongBeams, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM,
         std::string const& directory = "");
 
+    //! \brief Phase 1 (lock-free): Walk radix tree and claim matching blocks.
+    //! \details Caller must hold mCachedBlocksRootMutex.
+    [[nodiscard]] ClaimResult claimMatchingBlocks(
+        GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest);
+
+    //! \brief Phase 2 (lock-free): Onboard claimed host blocks and allocate non-matching blocks.
+    //! \details Caller must hold mCachedBlocksRootMutex.
+    [[nodiscard]] SizeType32 onboardAndAllocateBlocks(
+        GenerationRequest& sequence, LlmRequest& llmRequest, ClaimResult& claimResult);
+
     //! \brief Free block and all it's descendants. This makes block a claimed leaf block.
     void freeChildren(BlockPtr const& block);
@@ -1242,6 +1302,12 @@ class BlockManager
     void addSequence(
         GenerationRequest& sequence, SizeType32 numContextBlocks, SizeType32 windowSize, bool isShareLastContextBlock);
 
+    //! \brief Batch add sequences forwarding to WindowBlockManager::addSequenceBatch.
+    [[nodiscard]] std::vector<WindowBlockManager::BatchSeqStats> addSequenceBatch(
+        std::vector<GenerationRequest*> const& sequences, std::vector<SizeType32> const& inputLengths,
+        std::vector<SizeType32> const& numContextBlocksVec,
+        std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests, SizeType32 windowSize);
+
     void allocateBlock(GenerationRequest& sequence, SizeType32 windowSize);
 
     //! \brief According to request's current position, copy data from the last full block to the next block (ignoring
@@ -1732,6 +1798,15 @@ class BaseKVCacheManager
         OptionalRef<LlmRequest> llmRequest = std::nullopt)
         = 0;
 
+    //! \brief Batch add sequences with two-phase claim-then-onboard to prevent host offloading eviction.
+    //! \details Phase 1 claims all matching blocks across all requests (protecting them from eviction).
+    //!          Phase 2 onboards host blocks and allocates non-matching blocks.
+    //!          Requires block reuse enabled and single attention window.
+    virtual void addSequenceBatch(
+        std::vector<std::tuple<LlmRequest::RequestIdType, SizeType32, SizeType32>> const& requestInfos,
+        std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests)
+        = 0;
+
     [[nodiscard]] virtual std::optional removeSequence(LlmRequest::RequestIdType requestId,
         OptionalRef<LlmRequest> llmRequest = std::nullopt, bool pinOnRelease = false)
         = 0;
@@ -2102,6 +2177,10 @@ class KVCacheManager : public BaseKVCacheManager
     void addSequence(LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth,
         OptionalRef<LlmRequest> llmRequest = std::nullopt) override;
 
+    void addSequenceBatch(
+        std::vector<std::tuple<LlmRequest::RequestIdType, SizeType32, SizeType32>> const& requestInfos,
+        std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests) override;
+
     [[nodiscard]] std::optional removeSequence(LlmRequest::RequestIdType requestId,
         OptionalRef<LlmRequest> llmRequest = std::nullopt, bool pinOnRelease = false) override;
 
diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
index a10b42fbd86..42ea58ca9fb 100644
--- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
+++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@@ -1391,6 +1391,345 @@ SizeType32 WindowBlockManager::countReusableBlocks(
     return reusableBlocks;
 }
 
+WindowBlockManager::ClaimResult WindowBlockManager::claimMatchingBlocks(
+    GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest)
+{
+    // NOTE: Caller must hold mCachedBlocksRootMutex.
+    TLLM_CHECK_WITH_INFO(!(isRecurrentState()) || inputLength == llmRequest.getPromptLen(),
+        "Recurrent state does not support CP or truncation yet.");
+
+    ClaimResult result;
+    result.numContextBlocks = numContextBlocks;
+
+    auto const requestId = sequence.getRequestId();
+    auto const [seqIt, emplaceDone] = mAllocatedBlocksPerSeq.emplace(requestId, std::vector<BlockPtr>{});
+    TLLM_CHECK(emplaceDone);
+
+    // Prepare block keys — same logic as WindowBlockManager::addSequence lines 1437-1465
+    auto constexpr beamIdx = 0;
+    auto const& uniqueTokens = (mCacheType == CacheType::kSELF || mCacheType == CacheType::kSELFKONLY)
+        ? llmRequest.getUniqueTokens(beamIdx)
+        : *(llmRequest.getEncoderUniqueTokens().value());
+
+    auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, inputLength - 1, mTokensPerBlock, true);
+    if (inputLength % mTokensPerBlock == 1)
+    {
+        blockedUniqueTokens.emplace_back();
+    }
+
+    result.blockKeys = buildBlockKeys(blockedUniqueTokens, llmRequest);
+
+    auto config = llmRequest.getKvCacheRetentionConfig();
+    result.perBlockRetentions = config.value_or(executor::KvCacheRetentionConfig())
+                                    .getPerBlockRetentionPriorityDuration(mTokensPerBlock, inputLength);
+    result.mode = config.value_or(executor::KvCacheRetentionConfig()).getTransferMode();
+    result.directory = config.value_or(executor::KvCacheRetentionConfig()).getDirectory();
+
+    if (result.mode != executor::KvCacheTransferMode::DRAM && result.directory.empty())
+    {
+        TLLM_LOG_WARNING(
+            "Transfer mode %d specified without directory, falling back to DRAM mode", static_cast<int>(result.mode));
+        result.mode = executor::KvCacheTransferMode::DRAM;
+    }
+
+    TLLM_CHECK(result.perBlockRetentions.size() == static_cast<size_t>(numContextBlocks));
+
+    // Phase 1: Walk radix tree, claim matching blocks — no onboard, no getFreeBlock
+    // NOTE: Caller must hold mCachedBlocksRootMutex.
+
+    // Compute shareLastContextBlockAmongBeams — same logic as WindowBlockManager::addSequence
+    result.shareLastContextBlockAmongBeams = sequence.getBeamWidth() == 1;
+    if (isRecurrentState())
+    {
+        result.shareLastContextBlockAmongBeams |= inputLength % mTokensPerBlock == 0;
+    }
+
+    result.numSharedContextBlocks = result.shareLastContextBlockAmongBeams ? numContextBlocks : numContextBlocks - 1;
+    auto searchRoot = mCachedBlocksRoot;
+    auto blockItr = result.blockKeys.begin();
+
+    for (int bi = 0; bi < result.numSharedContextBlocks; ++bi)
+    {
+        auto [partialMatch, numMatched, matchingBlock] = (searchRoot != nullptr && blockItr != result.blockKeys.end())
+            ? searchRoot->findMatchingBlock(*blockItr, mEnablePartialReuse, mCopyOnPartialReuse)
+            : std::make_tuple(false, 0, nullptr);
+        if (isRecurrentState())
+        {
+            TLLM_CHECK(partialMatch == false);
+        }
+
+        if (matchingBlock != nullptr
+            && result.totalMatchedTokens + numMatched <= sequence.getCurrentPrepopulatedPromptLen())
+        {
+            ClaimResult::ClaimedBlock claimed;
+            claimed.block = matchingBlock;
+            claimed.numMatchedTokens
+                = numMatched > 0 ? numMatched : static_cast<SizeType32>(blockItr->uniqueTokens.size());
+            claimed.isPartialMatch = partialMatch;
+            claimed.needsCopy = false;
+            claimed.isPlaceholder = matchingBlock->isPlaceholder();
+
+            result.totalMatchedTokens += claimed.numMatchedTokens;
+            if (!claimed.isPlaceholder)
+            {
+                result.latestMatchingNonPlaceholderBlockIdx = bi;
+            }
+
+            // Priority update event
+            if (result.perBlockRetentions[bi].retentionPriority.has_value()
+                && matchingBlock->getPriority() != result.perBlockRetentions[bi].retentionPriority && mEventManager)
+            {
+                mEventManager->enqueueUpdatedEvent(tle::KVCacheUpdatedData(matchingBlock->getHash())
+                                                       .priorityUpdated(matchingBlock->getPriority(),
+                                                           *result.perBlockRetentions[bi].retentionPriority),
+                    mWindowSize);
+            }
+
+            if (partialMatch)
+            {
+                if (matchingBlock->hasRefs() || !matchingBlock->isLeaf())
+                {
+                    // Block in use or has children — needs copy in Phase 2.
+                    claimed.needsCopy = true;
+                    if (!matchingBlock->hasRefs())
+                    {
+                        // Unreferenced non-leaf: could be in the free queue and evictable.
+                        // Claim to protect during Phase 2 copies; deferred release at batch end.
+                        mEvictionPolicy->claimBlock(matchingBlock, result.perBlockRetentions[bi].retentionPriority,
+                            result.perBlockRetentions[bi].durationMs);
+                        result.claimedCopySource = matchingBlock;
+                    }
+                }
+                else
+                {
+                    // Leaf with no refs — claim it now (freeLeafBlock + claimBlock, no eviction)
+                    freeLeafBlock(matchingBlock);
+                    mEvictionPolicy->claimBlock(matchingBlock, result.perBlockRetentions[bi].retentionPriority,
+                        result.perBlockRetentions[bi].durationMs);
+                }
+                searchRoot = nullptr; // no matching for following blocks
+            }
+            else
+            {
+                // Full match — claim block (removes from free queue, protecting from eviction)
+                searchRoot = matchingBlock;
+                mEvictionPolicy->claimBlock(matchingBlock, result.perBlockRetentions[bi].retentionPriority,
+                    result.perBlockRetentions[bi].durationMs);
+            }
+
+            result.claimedBlocks.push_back(std::move(claimed));
+            ++blockItr;
+        }
+        else
+        {
+            // No match — stop matching, remaining blocks handled in Phase 2
+            break;
+        }
+    }
+
+    TLLM_LOG_DEBUG("%s::claimMatchingBlocks for request %lu - Claimed %zu blocks, %d matched tokens",
+        mLogPrefix.c_str(), sequence.getRequestId(), result.claimedBlocks.size(), result.totalMatchedTokens);
+
+    return result;
+}
+
+SizeType32 WindowBlockManager::onboardAndAllocateBlocks(
+    GenerationRequest& sequence, LlmRequest& llmRequest, ClaimResult& claimResult)
+{
+    // NOTE: Caller must hold mCachedBlocksRootMutex.
+    std::set<KVCacheBlock::IdType> reusedBlockIds;
+    auto blockItr = claimResult.blockKeys.begin();
+    SizeType32 bi = 0;
+
+    // Process claimed (matched) blocks: onboard + addBlockToAllBeams
+    for (auto& claimed : claimResult.claimedBlocks)
+    {
+        KVCacheBlock::IdType matchingBlockId = claimed.block->getBlockId();
+
+        if (claimed.isPartialMatch && claimed.needsCopy)
+        {
+            // Partial match needing copy: allocate new block, copy from source, use new block
+            auto newBlock = getFreeBlock(sequence, claimed.block->getPriority(), claimed.block->getDurationMs(),
+                claimResult.mode, claimResult.directory);
+            mTransferManager->onboard(
+                claimed.block, newBlock, mPools, claimed.numMatchedTokens, claimResult.mode, claimResult.directory);
+            claimed.block = newBlock;
+            if (blockItr != claimResult.blockKeys.end())
+            {
+                claimed.block->setBlockKey(
+                    *blockItr, blockItr->uniqueTokens.size() == static_cast<size_t>(mTokensPerBlock));
+            }
+            claimed.block->setHash();
+            TLLM_LOG_DEBUG("%s::onboardAndAllocateBlocks for request %lu - Copied partially filled block %d",
+                mLogPrefix.c_str(), sequence.getRequestId(), matchingBlockId);
+        }
+        else if (claimed.isPartialMatch)
+        {
+            // Partial leaf match — already claimed in Phase 1
+            TLLM_LOG_DEBUG("%s::onboardAndAllocateBlocks for request %lu - Reused partially filled block %d",
+                mLogPrefix.c_str(), sequence.getRequestId(), matchingBlockId);
+        }
+        else
+        {
+            // Full match — already claimed in Phase 1
+            TLLM_LOG_DEBUG("%s::onboardAndAllocateBlocks for request %lu - Matched full block %d", mLogPrefix.c_str(),
+                sequence.getRequestId(), matchingBlockId);
+        }
+
+        onboardBlock(sequence, claimed.block, claimResult.mode, claimResult.directory);
+        addBlockToAllBeams(claimed.block, sequence);
+        if (!claimed.isPlaceholder)
+        {
+            ++mReusedBlocks;
+            if (!reusedBlockIds.count(matchingBlockId))
+            {
+                reusedBlockIds.insert(matchingBlockId);
+                ++mReusedUniqueBlocks;
+            }
+        }
+        ++blockItr;
+        ++bi;
+    }
+
+    // Allocate non-matching shared context blocks
+    for (; bi < claimResult.numSharedContextBlocks; ++bi)
+    {
+        bool shouldAllocate = true;
+        if (isRecurrentState())
+        {
+            shouldAllocate = mLinearAttentionMetadata->shouldAllocateRecurrentStates(
+                /*currentBlockEndTokenIdx=*/(bi + 1) * mTokensPerBlock, llmRequest.getPromptLen(), mTokensPerBlock);
+            TLLM_LOG_DEBUG(
+                "%s::onboardAndAllocateBlocks - Recurrent state block %d. shouldAllocate=%d for sequence %lu",
+                mLogPrefix.c_str(), bi, shouldAllocate, sequence.getRequestId());
+        }
+
+        auto freeBlock = getFreeBlock(sequence,
+            claimResult.perBlockRetentions[bi].retentionPriority.value_or(
+                executor::KvCacheRetentionConfig::kDefaultRetentionPriority),
+            claimResult.perBlockRetentions[bi].durationMs, claimResult.mode, claimResult.directory,
+            /*wantPlaceholder=*/!shouldAllocate);
+        addBlockToAllBeams(freeBlock, sequence);
+        TLLM_LOG_DEBUG("%s::onboardAndAllocateBlocks - No match, allocated new block %d for sequence %lu",
+            mLogPrefix.c_str(), freeBlock->getBlockId(), sequence.getRequestId());
+        if (blockItr != claimResult.blockKeys.end())
+        {
+            freeBlock->setBlockKey(*blockItr, blockItr->uniqueTokens.size() == static_cast<size_t>(mTokensPerBlock));
+            ++blockItr;
+        }
+        freeBlock->setHash();
+        ++mMissedBlocks;
+    }
+
+    // Allocate non-shared last blocks (beam > 1)
+    auto const beamWidth = sequence.getBeamWidth();
+    for (int nbi = claimResult.numSharedContextBlocks; nbi < claimResult.numContextBlocks; ++nbi)
+    {
+        for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
+        {
+            auto freeBlock = getFreeBlock(sequence,
+                claimResult.perBlockRetentions[nbi].retentionPriority.value_or(
+                    executor::KvCacheRetentionConfig::kDefaultRetentionPriority),
+                claimResult.perBlockRetentions[nbi].durationMs, claimResult.mode, claimResult.directory);
+            addBlockToBeam(freeBlock, sequence, beamIdx);
+            if (blockItr != claimResult.blockKeys.end())
+            {
+                freeBlock->setBlockKey(
+                    *blockItr, blockItr->uniqueTokens.size() == static_cast<size_t>(mTokensPerBlock));
+                ++blockItr;
+            }
+            freeBlock->setHash();
+            TLLM_LOG_DEBUG("%s::onboardAndAllocateBlocks - Beam %d. Allocated non-shared block %d for bi %d",
+                mLogPrefix.c_str(), beamIdx, freeBlock->getBlockId(), nbi);
+        }
+        ++mMissedBlocks;
+        if (blockItr != claimResult.blockKeys.end())
+        {
+            ++blockItr;
+        }
+    }
+
+    // Finalize matched token count (purge trailing placeholders for recurrent states)
+    auto numMatchedTokens = claimResult.totalMatchedTokens;
+    if (isRecurrentState())
+    {
+        numMatchedTokens = (claimResult.latestMatchingNonPlaceholderBlockIdx + 1) * mTokensPerBlock;
+    }
+    sequence.setCurrentPrepopulatedPromptLen(numMatchedTokens);
+
+    // Update stats and return prepopulated length
+    mReusedTokens += static_cast<double>(numMatchedTokens);
+    auto constexpr beamIdx = 0;
+    auto const& uniqueTokens = (mCacheType == CacheType::kSELF || mCacheType == CacheType::kSELFKONLY)
+        ? llmRequest.getUniqueTokens(beamIdx)
+        : *(llmRequest.getEncoderUniqueTokens().value());
+    mTotalInputTokens += static_cast<double>(uniqueTokens.size());
+
+    SizeType32 numConnectorMatchedTokens = 0;
+    if (mKvCacheConnectorManager && !llmRequest.isDummyRequest())
+    {
+        numConnectorMatchedTokens
+            = mKvCacheConnectorManager->getNumNewMatchedTokens(llmRequest, sequence.getCurrentPrepopulatedPromptLen());
+    }
+
+    auto const totalPrepopulatedLen = sequence.getCurrentPrepopulatedPromptLen() + numConnectorMatchedTokens;
+    TLLM_LOG_DEBUG("%s::onboardAndAllocateBlocks: Request %lu, prepopulatedPromptLen %d, numConnectorMatchedTokens %d",
+        mLogPrefix.c_str(), llmRequest.mRequestId, sequence.getCurrentPrepopulatedPromptLen(),
+        numConnectorMatchedTokens);
+    return totalPrepopulatedLen;
+}
+
+std::vector<WindowBlockManager::BatchSeqStats> WindowBlockManager::addSequenceBatch(
+    std::vector<GenerationRequest*> const& sequences, std::vector<SizeType32> const& inputLengths,
+    std::vector<SizeType32> const& numContextBlocksVec,
+    std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests)
+{
+    auto const n = sequences.size();
+    std::vector<ClaimResult> claimResults(n);
+    std::vector<BatchSeqStats> results(n);
+
+    // Hold the lock for the entire two-phase operation.
+    std::lock_guard<std::mutex> lock(mCachedBlocksRootMutex);
+
+    // Phase 1: Claim all matching blocks across all requests.
+    for (size_t i = 0; i < n; ++i)
+    {
+        claimResults[i]
+            = claimMatchingBlocks(*sequences[i], inputLengths[i], numContextBlocksVec[i], llmRequests[i].get());
+    }
+
+    // Phase 2: Onboard + allocate for each request, snapshotting stats between requests.
+    for (size_t i = 0; i < n; ++i)
+    {
+        SizeType32 const preTotalBlocks = mAllocTotalBlocks;
+        SizeType32 const preNewBlocks = mAllocNewBlocks;
+        SizeType32 const preReused = mReusedBlocks;
+        SizeType32 const preMissed = mMissedBlocks;
+
+        results[i].prepopulatedLen = onboardAndAllocateBlocks(*sequences[i], llmRequests[i].get(), claimResults[i]);
+
+        results[i].allocTotalDelta = mAllocTotalBlocks - preTotalBlocks;
+        results[i].allocNewDelta = mAllocNewBlocks - preNewBlocks;
+        results[i].reusedDelta = mReusedBlocks - preReused;
+        results[i].missedDelta = mMissedBlocks - preMissed;
+    }
+
+    // Deferred release: return claimed partial-match copy sources to the free queue
+    // if they still have no refs. Deduplicate by block ID to avoid double-release
+    // when multiple requests partial-matched the same source.
+    std::set<KVCacheBlock::IdType> releasedSourceIds;
+    for (auto const& cr : claimResults)
+    {
+        if (cr.claimedCopySource && releasedSourceIds.insert(cr.claimedCopySource->getBlockId()).second
+            && !cr.claimedCopySource->hasRefs())
+        {
+            mEvictionPolicy->releaseBlock(cr.claimedCopySource);
+        }
+    }
+
+    return results;
+}
+
 bool WindowBlockManager::blockInRadixTree(BlockPtr const& block)
 {
     return !block->getUniqueTokens().empty() && block->getPrevBlock() != nullptr;
@@ -1631,6 +1970,15 @@ SizeType32 BlockManager::addSequence(GenerationRequest& sequence, SizeType32 inp
     return mWindowBlockManagers.at(windowSize).addSequence(sequence, inputLength, numContextBlocks, llmRequest);
 }
 
+std::vector<WindowBlockManager::BatchSeqStats> BlockManager::addSequenceBatch(
+    std::vector<GenerationRequest*> const& sequences, std::vector<SizeType32> const& inputLengths,
+    std::vector<SizeType32> const& numContextBlocksVec,
+    std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests, SizeType32 windowSize)
+{
+    return mWindowBlockManagers.at(windowSize)
+        .addSequenceBatch(sequences, inputLengths, numContextBlocksVec, llmRequests);
+}
+
 // There are two versions of WindowBlockManager::addSequence function.
 // This is called when block reuse is enabled.
 // Returns the total prepopulatedPromptLen (including connector matched tokens) for this window.
@@ -3073,6 +3421,83 @@ void KVCacheManager::addSequence(
     }
 }
 
+void KVCacheManager::addSequenceBatch(
+    std::vector<std::tuple<LlmRequest::RequestIdType, SizeType32, SizeType32>> const& requestInfos,
+    std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests)
+{
+    TLLM_CHECK(requestInfos.size() == llmRequests.size());
+    TLLM_CHECK_WITH_INFO(
+        !mBlockManager.isVariableWindow(), "addSequenceBatch does not support variable window attention");
+    if (requestInfos.empty())
+    {
+        return;
+    }
+    TLLM_CHECK_WITH_INFO(mEnableBlockReuse, "addSequenceBatch requires block reuse to be enabled");
+
+    auto const& [firstWindowSize, firstMetadata] = *mBlockManager.getWindowSizesMetadata().begin();
+
+    auto const n = requestInfos.size();
+
+    // --- Setup: create sequences, hold them, compute effective input length ---
+    std::vector<GenerationRequest*> sequences(n);
+    std::vector<SizeType32> inputLengths(n);
+    std::vector<SizeType32> numContextBlocksVec(n);
+
+    for (size_t i = 0; i < n; ++i)
+    {
+        auto const requestId = std::get<0>(requestInfos[i]);
+        auto const inputLength = std::get<1>(requestInfos[i]);
+        auto const beamWidth = std::get<2>(requestInfos[i]);
+        auto& llmRequest = llmRequests[i].get();
+
+        auto kvCacheRetentionConfig
+            = llmRequest.getKvCacheRetentionConfig().value_or(executor::KvCacheRetentionConfig());
+
+        auto const [seqIt, emplaceDone] = [&]
+        {
+            auto lck = std::scoped_lock(mSequencesMtx);
+            return mSequences.try_emplace(requestId, requestId, inputLength, beamWidth,
+                mBlockManager.getWindowSizesMetadata(), kvCacheRetentionConfig);
+        }();
+        TLLM_CHECK(emplaceDone);
+
+        sequences[i] = &seqIt->second;
+
+        if (!mBlockManager.isSequenceHeld(requestId))
+        {
+            mBlockManager.holdSequence(requestId);
+        }
+
+        auto const maxTokenNum = firstMetadata.maxTokenNum;
+        auto const temporaryAttentionWindow = firstMetadata.temporaryAttentionWindow;
+        inputLengths[i] = std::min(inputLength, maxTokenNum + temporaryAttentionWindow);
+        numContextBlocksVec[i] = tc::ceilDiv(inputLengths[i], getTokensPerBlock());
+    }
+
+    // --- Two-phase claim-then-onboard under a single lock ---
+    auto const batchResults
+        = mBlockManager.addSequenceBatch(sequences, inputLengths, numContextBlocksVec, llmRequests, firstWindowSize);
+
+    // --- Finalize: update offsets, set prepopulated length, update per-request stats ---
+    for (size_t i = 0; i < n; ++i)
+    {
+        auto& llmRequest = llmRequests[i].get();
+        auto const& stats = batchResults[i];
+
+        mBlockManager.updateSequenceCacheBlockOffsets(*sequences[i], firstWindowSize);
+
+        TLLM_LOG_DEBUG("KVCacheManager::addSequenceBatch: Setting prepopulatedPromptLen to %d for request %lu",
+            stats.prepopulatedLen, llmRequest.mRequestId);
+        llmRequest.setPrepopulatedPromptLen(stats.prepopulatedLen, getTokensPerBlock());
+        llmRequest.setEstimatedReusableTokens(0);
+
+        llmRequest.updateAllocTotalBlocksPerRequest(stats.allocTotalDelta);
+        llmRequest.updateAllocNewBlocksPerRequest(stats.allocNewDelta);
+        llmRequest.updateReusedBlocksPerRequest(stats.reusedDelta);
+        llmRequest.updateMissedBlocksPerRequest(stats.missedDelta);
+    }
+}
+
 void KVCacheManager::storeContextBlocks(LlmRequest const& llmRequest)
 {
     auto const requestId = llmRequest.mRequestId;
diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
index fe7687aee77..75bd34cbdf1 100755
--- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
@@ -68,7 +68,7 @@ std::optional from_torch(std::optiona
 class PyKvCacheManager : public tbk::BaseKVCacheManager
 {
 public:
-    NB_TRAMPOLINE(tbk::BaseKVCacheManager, 36);
+    NB_TRAMPOLINE(tbk::BaseKVCacheManager, 37);
 
     // using BaseKVCacheManager::BaseKVCacheManager; // Inherit constructors
     void allocatePools(bool useUvm = false) override
@@ -122,6 +122,13 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager
         NB_OVERRIDE_PURE(addSequence, requestId, inputLength, beamWidth, llmRequest);
     }
 
+    void addSequenceBatch(
+        std::vector<std::tuple<tb::LlmRequest::RequestIdType, SizeType32, SizeType32>> const& requestInfos,
+        std::vector<std::reference_wrapper<tb::LlmRequest>> const& llmRequests) override
+    {
+        NB_OVERRIDE_PURE(addSequenceBatch, requestInfos, llmRequests);
+    }
+
     std::optional removeSequence(tb::LlmRequest::RequestIdType requestId,
         tensorrt_llm::common::OptionalRef<tb::LlmRequest> llmRequest = std::nullopt,
         bool pinOnRelease = false) override
@@ -393,6 +400,27 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
         .def("add_token", &BaseKVCacheManager::addToken, nb::call_guard<nb::gil_scoped_release>())
         .def("get_token_count", &BaseKVCacheManager::getTokenCount, nb::arg("request_id"))
         .def("add_sequence", &BaseKVCacheManager::addSequence, nb::call_guard<nb::gil_scoped_release>())
+        .def(
+            "add_sequence_batch",
+            [](tbk::BaseKVCacheManager& self, nb::list requestInfosList, nb::list llmRequestsList)
+            {
+                // Marshal Python inputs while GIL is held.
+                std::vector<std::tuple<tb::LlmRequest::RequestIdType, SizeType32, SizeType32>> requestInfos;
+                std::vector<std::reference_wrapper<tb::LlmRequest>> llmRequests;
+                requestInfos.reserve(nb::len(requestInfosList));
+                llmRequests.reserve(nb::len(llmRequestsList));
+                for (size_t i = 0; i < nb::len(requestInfosList); ++i)
+                {
+                    auto info = nb::cast<nb::tuple>(requestInfosList[i]);
+                    requestInfos.emplace_back(nb::cast<tb::LlmRequest::RequestIdType>(info[0]),
+                        nb::cast<SizeType32>(info[1]), nb::cast<SizeType32>(info[2]));
+                    llmRequests.push_back(std::ref(nb::cast<tb::LlmRequest&>(llmRequestsList[i])));
+                }
+                // Release GIL only for the C++ call.
+                nb::gil_scoped_release release;
+                self.addSequenceBatch(requestInfos, llmRequests);
+            },
+            nb::arg("request_infos"), nb::arg("llm_requests"))
         .def("remove_sequence", &BaseKVCacheManager::removeSequence, nb::call_guard<nb::gil_scoped_release>())
         .def("pin_blocks", &BaseKVCacheManager::pinBlocks, nb::call_guard<nb::gil_scoped_release>())
         .def("scheduling_remove_sequence", &BaseKVCacheManager::schedulingRemoveSequence,
diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
index 971ff8c5402..aebd8b96f07 100644
--- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -619,6 +619,14 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
             # wait for all pending work to finish before launching offload/onboarding/partial copy
             self.impl.sync_transfer_manager_with_buffer_manager()
 
+        # Collect first-chunk requests eligible for batch add_sequence.
+        # When block reuse is enabled, addSequenceBatch uses a two-phase
+        # claim-then-onboard strategy that prevents host offloading from
+        # evicting reusable blocks in the radix tree.
+        batch_request_infos = []
+        batch_llm_requests = []
+        batch_ctx_requests = []
+
         # allocate KV Cache
         for req in scheduled_batch.context_requests:
             req_beam_width = req.sampling_config.beam_width
@@ -635,18 +643,41 @@
                 else:
                     if req.is_first_context_chunk and self._kv_connector_should_add_sequence(
                             req):
-                        self.impl.add_sequence(req.py_request_id,
-                                               req.prompt_len, req_beam_width,
-                                               req)
-                        for _ in range(self.num_extra_kv_tokens):
-                            self.impl.add_token(req.py_request_id)
-                        for _ in range(get_draft_token_length(req)):
-                            self.impl.add_token(req.py_request_id)
-
-                        if self.kv_connector_manager is not None:
-                            block_ids = self.get_cache_indices(req)
-                            self.kv_connector_manager.update_state_after_alloc(
-                                req, block_ids)
+                        if self.enable_block_reuse and not self.is_vswa:
+                            # Batch path: two-phase claim-then-onboard
+                            # (not supported for VSWA which needs multi-window addSequence)
+                            batch_request_infos.append(
+                                (req.py_request_id, req.prompt_len,
+                                 req_beam_width))
+                            batch_llm_requests.append(req)
+                            batch_ctx_requests.append(req)
+                        else:
+                            self.impl.add_sequence(req.py_request_id,
+                                                   req.prompt_len,
+                                                   req_beam_width, req)
+                            for _ in range(self.num_extra_kv_tokens):
+                                self.impl.add_token(req.py_request_id)
+                            for _ in range(get_draft_token_length(req)):
+                                self.impl.add_token(req.py_request_id)
+
+                            if self.kv_connector_manager is not None:
+                                block_ids = self.get_cache_indices(req)
+                                self.kv_connector_manager.update_state_after_alloc(
+                                    req, block_ids)
+
+        if batch_request_infos:
+            self.impl.add_sequence_batch(batch_request_infos,
+                                         batch_llm_requests)
+            for req in batch_ctx_requests:
+                for _ in range(self.num_extra_kv_tokens):
+                    self.impl.add_token(req.py_request_id)
+                for _ in range(get_draft_token_length(req)):
+                    self.impl.add_token(req.py_request_id)
+
+                if self.kv_connector_manager is not None:
+                    block_ids = self.get_cache_indices(req)
+                    self.kv_connector_manager.update_state_after_alloc(
+                        req, block_ids)
 
         # A request may change from `context_requests_chunking` to `context_requests_last_chunk` in `add_sequence` due to KV cache reuse, so we rebuild the context request lists here.
         scheduled_batch.reset_context_requests()