Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,56 @@ class WindowBlockManager
//! \brief Assign blocks for new sequence. Does not try to reuse blocks.
void addSequence(GenerationRequest& sequence, SizeType32 numContextBlocks, bool isShareLastContextBlock);

//! \brief Per-request block allocation statistics from batch addSequence.
//! \details One entry is produced per request by addSequenceBatch. The *Delta fields are
//! this request's contributions to the manager's allocation/reuse counters.
//! NOTE(review): field semantics below are inferred from names and the addSequenceBatch
//! \return doc — confirm against the .cpp implementation.
struct BatchSeqStats
{
    SizeType32 prepopulatedLen{0};  //!< prepopulated prompt length (tokens covered by matched/claimed blocks)
    SizeType32 allocTotalDelta{0};  //!< delta to the total-allocated-blocks counter for this request
    SizeType32 allocNewDelta{0};    //!< delta to the newly-allocated-blocks counter (blocks not obtained via reuse)
    SizeType32 reusedDelta{0};      //!< delta to the reused-blocks counter
    SizeType32 missedDelta{0};      //!< delta to the reuse-miss counter
};

//! \brief Result of Phase 1 (claim-only) of batch addSequence.
//! \details Holds matched blocks and prepared data so Phase 2 can proceed without
//! re-traversing the radix tree.
struct ClaimResult
{
    //! \brief One claimed (matched) block plus the per-block decisions Phase 2 needs.
    struct ClaimedBlock
    {
        BlockPtr block;               //!< the matched block claimed in Phase 1
        SizeType32 numMatchedTokens;  //!< tokens matched in this block
        bool isPartialMatch;          //!< true if only a prefix of the block's tokens matched
        bool needsCopy;     //!< partial match on block with refs or non-leaf (needs getFreeBlock + copy in Phase 2)
        bool isPlaceholder; //!< placeholder block (linear attention recurrent states)
    };

    std::vector<ClaimedBlock> claimedBlocks; //!< claimed blocks, presumably in prompt order — confirm in .cpp
    BlockPtr claimedCopySource; //!< unreferenced non-leaf partial-match source claimed to protect from eviction
    SizeType32 totalMatchedTokens{0}; //!< total matched tokens across claimedBlocks
    //! Index into claimedBlocks of the latest non-placeholder match; -1 if none.
    SizeType32 latestMatchingNonPlaceholderBlockIdx{-1};
    SizeType32 numSharedContextBlocks{0}; //!< context blocks shared among beams (see addSequence overloads)
    SizeType32 numContextBlocks{0};       //!< total context blocks for this request
    bool shareLastContextBlockAmongBeams{true}; //!< mirrors addSequence's isShareLastContextBlock flag
    std::vector<BlockKey> blockKeys; //!< per-block radix-tree keys prepared for Phase 2
    std::vector<executor::RetentionPriorityAndDuration> perBlockRetentions; //!< per-block retention settings
    executor::KvCacheTransferMode mode{executor::KvCacheTransferMode::DRAM}; //!< transfer mode for onboarding
    std::string directory; //!< directory used by non-DRAM transfer modes — TODO confirm exact use
};

//! \brief Batch add sequences with two-phase claim-then-onboard under a single lock.
//! \details Phase 1 claims all matching blocks across all requests (protecting from eviction).
//! Phase 2 onboards host blocks and allocates non-matching blocks.
//! The mCachedBlocksRootMutex is held for the entire operation.
//! \param sequences Per-request GenerationRequest references (parallel with other vectors).
//! \param inputLengths Per-request effective input length.
//! \param numContextBlocksVec Per-request number of context blocks.
//! \param llmRequests Per-request LlmRequest references.
//! \return Per-request BatchSeqStats (prepopulated prompt length plus allocation/reuse counter deltas).
[[nodiscard]] std::vector<BatchSeqStats> addSequenceBatch(std::vector<GenerationRequest*> const& sequences,
std::vector<SizeType32> const& inputLengths, std::vector<SizeType32> const& numContextBlocksVec,
std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests);

//! \brief Allocate new block for each beam of the sequence.
//! \details Might free cached blocks if no free blocks are available.
void allocateBlock(GenerationRequest& sequence, bool shareAmongBeams);
Expand Down Expand Up @@ -1048,6 +1098,16 @@ class WindowBlockManager
bool shareLastContextBlockAmongBeams, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM,
std::string const& directory = "");

//! \brief Phase 1 of batch addSequence: walk radix tree and claim matching blocks.
//! \details Acquires no locks itself; caller must hold mCachedBlocksRootMutex.
[[nodiscard]] ClaimResult claimMatchingBlocks(
GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest);

//! \brief Phase 2 of batch addSequence: onboard claimed host blocks and allocate non-matching blocks.
//! \details Acquires no locks itself; caller must hold mCachedBlocksRootMutex.
[[nodiscard]] SizeType32 onboardAndAllocateBlocks(
GenerationRequest& sequence, LlmRequest& llmRequest, ClaimResult& claimResult);

//! \brief Free block and all it's descendants. This makes block a claimed leaf block.
void freeChildren(BlockPtr const& block);

Expand Down Expand Up @@ -1242,6 +1302,12 @@ class BlockManager
void addSequence(
GenerationRequest& sequence, SizeType32 numContextBlocks, SizeType32 windowSize, bool isShareLastContextBlock);

//! \brief Batch add sequences forwarding to WindowBlockManager::addSequenceBatch.
[[nodiscard]] std::vector<WindowBlockManager::BatchSeqStats> addSequenceBatch(
std::vector<GenerationRequest*> const& sequences, std::vector<SizeType32> const& inputLengths,
std::vector<SizeType32> const& numContextBlocksVec,
std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests, SizeType32 windowSize);

void allocateBlock(GenerationRequest& sequence, SizeType32 windowSize);

//! \brief According to request's current position, copy data from the last full block to the next block (ignoring
Expand Down Expand Up @@ -1732,6 +1798,15 @@ class BaseKVCacheManager
OptionalRef<LlmRequest> llmRequest = std::nullopt)
= 0;

//! \brief Batch add sequences with two-phase claim-then-onboard to prevent host offloading eviction.
//! \details Phase 1 claims all matching blocks across all requests (protecting them from eviction).
//! Phase 2 onboards host blocks and allocates non-matching blocks.
//! Requires block reuse enabled and single attention window.
virtual void addSequenceBatch(
std::vector<std::tuple<LlmRequest::RequestIdType, SizeType32, SizeType32>> const& requestInfos,
std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests)
= 0;

[[nodiscard]] virtual std::optional<KVCacheBlock::IdType> removeSequence(LlmRequest::RequestIdType requestId,
OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinOnRelease = false)
= 0;
Expand Down Expand Up @@ -2102,6 +2177,10 @@ class KVCacheManager : public BaseKVCacheManager
void addSequence(LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth,
OptionalRef<LlmRequest> llmRequest = std::nullopt) override;

void addSequenceBatch(
std::vector<std::tuple<LlmRequest::RequestIdType, SizeType32, SizeType32>> const& requestInfos,
std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests) override;

[[nodiscard]] std::optional<KVCacheBlock::IdType> removeSequence(LlmRequest::RequestIdType requestId,
OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinOnRelease = false) override;

Expand Down
Loading
Loading