[None][fix] Batch addSequence with pre-claim to fix host offloading M… #12892

Open
liji-nv wants to merge 1 commit into NVIDIA:main from liji-nv:fix/batch-addsequence-mnt-overflow-v2

Conversation

Collaborator

@liji-nv liji-nv commented Apr 9, 2026

…NT overflow

When host offloading is enabled, onboarding a host block to GPU during addSequence can trigger eviction of other reusable host blocks from the radix tree. This causes actual KV cache reuse to be less than the scheduler estimated, leading to max_num_tokens (MNT) overflow assertions.

Add a new addSequenceBatch API that processes all first-chunk context requests in two phases:

  • Phase 1: Walk the radix tree and claimBlock() for all matching blocks across all requests. No onboarding, no allocation. This protects reusable blocks from eviction.
  • Phase 2: Onboard host blocks and allocate non-matching blocks. Since all reusable blocks are already claimed, evictions during onboarding cannot touch them.

On the Python side, replace the TOCTOU-prone revalidation loop (count_reusable_blocks + budget check) with a single batch call.
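The two-phase scheme described above can be illustrated with a self-contained toy model. All names here (BlockPool, claim, allocate) are illustrative, not the actual TRT-LLM API:

```python
# Toy model of why Phase 1 pre-claiming prevents the eviction race.
class BlockPool:
    def __init__(self, capacity):
        self.capacity = capacity
        self.blocks = {}  # block_id -> claimed flag

    def insert(self, block_id):
        self.blocks[block_id] = False  # cached, reusable, unclaimed

    def claim(self, block_id):
        self.blocks[block_id] = True  # pinned against eviction

    def allocate(self, n, prefix):
        """Allocate n fresh blocks, evicting unclaimed blocks if at capacity."""
        evicted = []
        while len(self.blocks) + n > self.capacity:
            victim = next((b for b, c in self.blocks.items() if not c), None)
            if victim is None:
                raise RuntimeError("out of evictable blocks")
            del self.blocks[victim]
            evicted.append(victim)
        for i in range(n):
            self.blocks[f"{prefix}{i}"] = True  # new blocks are in use
        return evicted

def add_sequence_batch(pool, requests):
    # Phase 1: claim every reusable block across all requests, no allocation.
    for req in requests:
        for block_id in req["reusable"]:
            pool.claim(block_id)
    # Phase 2: onboarding/allocation can now only evict unclaimed blocks.
    evicted = []
    for idx, req in enumerate(requests):
        evicted += pool.allocate(req["new_blocks"], prefix=f"r{idx}_")
    return evicted
```

With the claims in place, only stale unclaimed blocks are evicted. If claiming and allocation were interleaved per request instead, a later allocation could evict a reusable block the scheduler's estimate counted on, which is exactly the MNT overflow scenario.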

Summary by CodeRabbit

Release Notes

  • New Features
    • Added batch sequence onboarding API for improved KV cache management, enabling more efficient allocation and reuse of cache blocks for multiple concurrent sequences.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
@liji-nv liji-nv requested a review from a team as a code owner April 9, 2026 14:01
Collaborator Author

liji-nv commented Apr 9, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #42543 [ run ] triggered by Bot. Commit: 2344609 Link to invocation

@coderabbitai
Contributor

coderabbitai bot commented Apr 9, 2026

📝 Walkthrough

Walkthrough

Introduces a new two-phase batch sequence onboarding API for KV cache management. Phase 1 claims matching cached blocks under lock; Phase 2 onboards and allocates remaining blocks. Adds corresponding methods to WindowBlockManager, BlockManager, BaseKVCacheManager, and KVCacheManager classes, along with Python bindings and integration into the resource manager's context allocation path when block reuse is enabled.

Changes

Cohort / File(s) Summary
KV Cache Manager Core
cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h, cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
Added two-phase batch sequence onboarding: new ClaimResult struct and three methods (claimMatchingBlocks, onboardAndAllocateBlocks, addSequenceBatch) in WindowBlockManager; addSequenceBatch delegation in BlockManager; new pure virtual addSequenceBatch in BaseKVCacheManager and implementation in KVCacheManager with input validation, sequence creation, cache offset updates, and statistics tracking.
Python Bindings
cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
Added add_sequence_batch Python-exposed binding method on BaseKVCacheManager that marshals list arguments into C++ tuples and reference vectors, forwarding to the underlying C++ implementation.
Resource Manager Integration
tensorrt_llm/_torch/pyexecutor/resource_manager.py
Modified prepare_resources method to conditionally batch eligible first-context-chunk requests when block reuse is enabled, calling add_sequence_batch() once per batch instead of per-request add_sequence calls, while preserving the original path for non-block-reuse scenarios.

Sequence Diagram(s)

sequenceDiagram
    participant ResourceMgr as Resource Manager
    participant KVCacheMgr as KVCacheManager
    participant BlockMgr as BlockManager
    participant WindowBlockMgr as WindowBlockManager
    
    ResourceMgr->>KVCacheMgr: addSequenceBatch(requestInfos, llmRequests)
    
    activate KVCacheMgr
    KVCacheMgr->>KVCacheMgr: Validate constraints (block reuse enabled)
    KVCacheMgr->>KVCacheMgr: Create GenerationRequest sequences<br/>Compute inputLengths & numContextBlocksVec
    KVCacheMgr->>BlockMgr: addSequenceBatch(sequences, inputLengths,<br/>numContextBlocksVec, llmRequests, windowSize)
    
    activate BlockMgr
    BlockMgr->>WindowBlockMgr: addSequenceBatch(...)
    
    activate WindowBlockMgr
    WindowBlockMgr->>WindowBlockMgr: Acquire mCachedBlocksRootMutex
    
    loop For each sequence (Phase 1)
        WindowBlockMgr->>WindowBlockMgr: claimMatchingBlocks(sequence, inputLength,<br/>numContextBlocks, llmRequest)
        note over WindowBlockMgr: Radix-tree walk, claim blocks,<br/>update priorities
    end
    
    loop For each sequence (Phase 2)
        WindowBlockMgr->>WindowBlockMgr: onboardAndAllocateBlocks(sequence,<br/>llmRequest, claimResult)
        note over WindowBlockMgr: Onboard copied blocks, allocate<br/>non-matching & non-shared blocks,<br/>finalize prepopulated length
    end
    
    WindowBlockMgr->>WindowBlockMgr: Release mCachedBlocksRootMutex
    WindowBlockMgr-->>BlockMgr: Return prepopulated lengths
    deactivate WindowBlockMgr
    
    BlockMgr-->>KVCacheMgr: Return prepopulated lengths
    deactivate BlockMgr
    
    KVCacheMgr->>KVCacheMgr: Update per-window cache offsets
    KVCacheMgr->>KVCacheMgr: Set llmRequest prepopulated<br/>prompt lengths & clear reusable tokens
    KVCacheMgr-->>ResourceMgr: Return (batch complete)
    deactivate KVCacheMgr
    
    ResourceMgr->>ResourceMgr: Apply per-token add_token calls<br/>& KV connector updates

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 16.67%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Title check ✅ Passed: the title concisely summarizes the main change (batching addSequence with pre-claim to fix host offloading MNT overflow), clearly identifies this as a fix, and captures the key technical approach.
  • Description check ✅ Passed: the description comprehensively explains the problem (host offloading eviction causing MNT overflow), the solution (two-phase batch API with pre-claiming), and implementation details on both the C++ and Python sides. However, it is provided outside the template structure, with no explicit Test Coverage or Checklist sections completed.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h`:
- Around line 1789-1796: The Python trampoline class PyKvCacheManager is missing
a pure-virtual override for the new C++ method addSequenceBatch, causing Python
subclasses not to dispatch; add an override method in PyKvCacheManager with the
signature matching addSequenceBatch that calls
NB_OVERRIDE_PURE(addSequenceBatch, requestInfos, llmRequests), and increment the
NB_TRAMPOLINE count for tbk::BaseKVCacheManager from 36 to 37 so the trampoline
table matches the new virtual. Ensure the override uses the exact parameter
types (std::vector<std::tuple<tb::LlmRequest::RequestIdType, SizeType32,
SizeType32>> const& requestInfos and
std::vector<std::reference_wrapper<tb::LlmRequest>> const& llmRequests) and that
NB_TRAMPOLINE is updated accordingly.

In `@cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp`:
- Around line 3408-3411: The current approach captures baseline totals in
numAllocTotalBlocksPre, numAllocNewBlocksPre, numReusedBlocksPre and
numMissedBlocksPre once for the whole batch and then computes per-request stats
by subtracting those global baselines, which causes earlier requests' work to be
attributed to later ones; instead compute per-request deltas either by
snapshotting the totals immediately before and after processing each individual
request (or at the exact request completion point) or by recording increments
during Phase 2 as each request performs allocations/reuses/misses; update the
logic in
updateAllocTotalPerRequest/updateAllocNewPerRequest/updateReusedPerRequest/updateMissedPerRequest
to use these per-request snapshots or Phase 2 counters rather than the
single-batch pre-snap variables so each request only accounts for its own
changes.
- Around line 3387-3449: Before calling mBlockManager.addSequenceBatch in
KVCacheManager::addSequenceBatch, enforce the same recurrent-state chunking
guard that WindowBlockManager::addSequence rejects: for each request, if its
kvCacheRetentionConfig indicates a recurrent-state retention mode and
inputLength != llmRequest.getPromptLen(), fail early (use TLLM_CHECK_WITH_INFO
or similar) with a message that batched recurrent-state requests must have
inputLength equal to the prompt length; perform this check using the
already-captured kvCacheRetentionConfig and llmRequest variables for each index
prior to invoking mBlockManager.addSequenceBatch.

In `@cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp`:
- Around line 396-413: The lambda bound to "add_sequence_batch" currently uses
nb::call_guard<nb::gil_scoped_release>() which drops the GIL for the entire body
and makes nb::len, list indexing, and nb::cast calls unsafe; change the
implementation so the GIL is held while marshalling the Python inputs into the
C++ locals (i.e., build requestInfos and llmRequests using nb::len, nb::cast,
nb::tuple, nb::list, std::ref) and only release the GIL immediately before
calling BaseKVCacheManager::addSequenceBatch (call self.addSequenceBatch under a
scoped GIL release). Ensure nb::call_guard<nb::gil_scoped_release>() is applied
only around the actual addSequenceBatch invocation or use an explicit
gil_scoped_release scope for that call.

In `@tensorrt_llm/_torch/pyexecutor/resource_manager.py`:
- Around line 644-678: The batching path currently calls
impl.add_sequence_batch(batch_request_infos, batch_llm_requests) for block-reuse
requests which breaks VSWA; change the control so batching only happens when
self.is_vswa is False (i.e., gate the add_sequence_batch call with not
self.is_vswa) and if self.is_vswa is True, fall back to the per-request
impl.add_sequence(...) flow used earlier (including the subsequent
impl.add_token(...) loops and kv_connector_manager.update_state_after_alloc
calls for each req). Ensure the logic around enable_block_reuse,
batch_request_infos, batch_llm_requests and batch_ctx_requests remains
consistent so VSWA continues using the single-request path.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 21f2daa7-dcc7-468a-a725-8af6a3a2a2f3

📥 Commits

Reviewing files that changed from the base of the PR and between 3e942cc and 2344609.

📒 Files selected for processing (4)
  • cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
  • cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
  • cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
  • tensorrt_llm/_torch/pyexecutor/resource_manager.py

Comment on lines +1789 to +1796
//! \brief Batch add sequences with two-phase claim-then-onboard to prevent host offloading eviction.
//! \details Phase 1 claims all matching blocks across all requests (protecting them from eviction).
//! Phase 2 onboards host blocks and allocates non-matching blocks.
//! Requires block reuse enabled and single attention window.
virtual void addSequenceBatch(
    std::vector<std::tuple<LlmRequest::RequestIdType, SizeType32, SizeType32>> const& requestInfos,
    std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests)
    = 0;

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
rg -n -C2 'class PyKvCacheManager|NB_TRAMPOLINE\(tbk::BaseKVCacheManager|addSequenceBatch' \
  cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h \
  cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp

Repository: NVIDIA/TensorRT-LLM

Length of output: 4248


🏁 Script executed:

#!/bin/bash
# Check the header file around the new pure virtual
echo "=== NEW PURE VIRTUAL (lines 1788-1800) ==="
sed -n '1788,1800p' cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

# Check what class implements this virtual
echo -e "\n=== IMPLEMENTATION CLASS (lines 2160-2175) ==="
sed -n '2160,2175p' cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

# Check PyKvCacheManager class definition in nanobind
echo -e "\n=== PYKVCACHEMANAGER CLASS (lines 68-450) ==="
sed -n '68,450p' cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp

Repository: NVIDIA/TensorRT-LLM

Length of output: 17913


Add missing nanobind override for new addSequenceBatch pure virtual.

The header declares a new pure virtual addSequenceBatch(...) at lines 1793–1796, and the nanobind bindings file includes a Python-facing .def("add_sequence_batch", ...) wrapper at line 411 that calls self.addSequenceBatch(...). However, PyKvCacheManager lacks the required NB_OVERRIDE_PURE(addSequenceBatch, requestInfos, llmRequests) override and the NB_TRAMPOLINE count remains at 36 instead of 37.

Without this override, Python subclasses of BaseKVCacheManager cannot properly dispatch this method through the C++ virtual table.

Required nanobind fix

Add to PyKvCacheManager class in cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp:

    void addSequenceBatch(
        std::vector<std::tuple<tb::LlmRequest::RequestIdType, SizeType32, SizeType32>> const& requestInfos,
        std::vector<std::reference_wrapper<tb::LlmRequest>> const& llmRequests) override
    {
        NB_OVERRIDE_PURE(addSequenceBatch, requestInfos, llmRequests);
    }

And update the trampoline declaration from:

    NB_TRAMPOLINE(tbk::BaseKVCacheManager, 36);

to:

    NB_TRAMPOLINE(tbk::BaseKVCacheManager, 37);

Comment on lines +1484 to +1498
if (partialMatch)
{
    if (matchingBlock->hasRefs() || !matchingBlock->isLeaf())
    {
        // Block in use or not leaf — needs copy in Phase 2. Don't claim source (it has refs, safe).
        claimed.needsCopy = true;
    }
    else
    {
        // Leaf with no refs — claim it now (freeLeafBlock + claimBlock, no eviction)
        freeLeafBlock(matchingBlock);
        mEvictionPolicy->claimBlock(matchingBlock, result.perBlockRetentions[bi].retentionPriority,
            result.perBlockRetentions[bi].durationMs);
    }
    searchRoot = nullptr; // no matching for following blocks
⚠️ Potential issue | 🔴 Critical

Claim unreferenced partial-match sources before Phase 2 copies.

With copyOnPartialReuse enabled, findMatchingBlock() can return a partial match that is not a leaf and also has no refs. This branch marks it needsCopy but leaves it in the free queue, so a later onboard/allocation in Phase 2 can evict the source before the copy happens. That recreates the race this batch API is trying to eliminate. Line 1488's “it has refs, safe” assumption only holds for half of the condition.

If the matched source has no refs, claim it here and release it after the copy completes.
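A minimal sketch of that suggestion, with hypothetical names rather than the actual KVCacheBlock API: claim the copy source only when it has no live refs, and remember that the caller must release the claim after the Phase 2 copy completes.

```python
# Illustrative sketch: pin an unreferenced partial-match source before Phase 2.
class Block:
    def __init__(self, name, refs=0):
        self.name = name
        self.refs = refs
        self.claimed = False

    def has_refs(self):
        return self.refs > 0

def protect_copy_source(block):
    """Claim the copy source only when nothing else pins it.
    Returns True if the caller must release the claim after the copy."""
    if not block.has_refs():
        block.claimed = True  # otherwise a Phase 2 eviction could free it
        return True
    return False  # live refs already keep it out of the free queue
```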

Comment on lines +3387 to +3449
void KVCacheManager::addSequenceBatch(
    std::vector<std::tuple<LlmRequest::RequestIdType, SizeType32, SizeType32>> const& requestInfos,
    std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests)
{
    TLLM_CHECK(requestInfos.size() == llmRequests.size());
    TLLM_CHECK_WITH_INFO(
        !mBlockManager.isVariableWindow(), "addSequenceBatch does not support variable window attention");
    if (requestInfos.empty())
    {
        return;
    }
    TLLM_CHECK_WITH_INFO(mEnableBlockReuse, "addSequenceBatch requires block reuse to be enabled");

    auto const& [firstWindowSize, firstMetadata] = *mBlockManager.getWindowSizesMetadata().begin();

    auto const n = requestInfos.size();

    // --- Setup: create sequences, hold them, compute effective input length ---
    std::vector<GenerationRequest*> sequences(n);
    std::vector<SizeType32> inputLengths(n);
    std::vector<SizeType32> numContextBlocksVec(n);
    std::vector<SizeType32> numAllocTotalBlocksPre(n);
    std::vector<SizeType32> numAllocNewBlocksPre(n);
    std::vector<SizeType32> numReusedBlocksPre(n);
    std::vector<SizeType32> numMissedBlocksPre(n);

    for (size_t i = 0; i < n; ++i)
    {
        auto const& [requestId, inputLength, beamWidth] = requestInfos[i];
        auto& llmRequest = llmRequests[i].get();

        auto kvCacheRetentionConfig
            = llmRequest.getKvCacheRetentionConfig().value_or(executor::KvCacheRetentionConfig());

        auto const [seqIt, emplaceDone] = [&]
        {
            auto lck = std::scoped_lock(mSequencesMtx);
            return mSequences.try_emplace(requestId, requestId, inputLength, beamWidth,
                mBlockManager.getWindowSizesMetadata(), kvCacheRetentionConfig);
        }();
        TLLM_CHECK(emplaceDone);

        sequences[i] = &seqIt->second;
        numAllocTotalBlocksPre[i] = mBlockManager.getNumAllocTotalBlocks();
        numAllocNewBlocksPre[i] = mBlockManager.getNumAllocNewBlocks();
        numReusedBlocksPre[i] = mBlockManager.getNumReusedBlocks();
        numMissedBlocksPre[i] = mBlockManager.getNumMissedBlocks();

        if (!mBlockManager.isSequenceHeld(requestId))
        {
            mBlockManager.holdSequence(requestId);
        }

        auto const maxTokenNum = firstMetadata.maxTokenNum;
        auto const temporaryAttentionWindow = firstMetadata.temporaryAttentionWindow;
        inputLengths[i] = std::min(inputLength, maxTokenNum + temporaryAttentionWindow);
        numContextBlocksVec[i] = tc::ceilDiv(inputLengths[i], getTokensPerBlock());
    }

    // --- Two-phase claim-then-onboard under a single lock ---
    auto const prepopulatedLens
        = mBlockManager.addSequenceBatch(sequences, inputLengths, numContextBlocksVec, llmRequests, firstWindowSize);


⚠️ Potential issue | 🟠 Major

Preserve the recurrent-state chunking guard in the batch path.

WindowBlockManager::addSequence() still rejects recurrent-state requests unless inputLength == llmRequest.getPromptLen() at Line 1951, but KVCacheManager::addSequenceBatch() never enforces that invariant before dispatching the batch. A batched first-chunk recurrent-state request can now enter a path that the single-request flow explicitly declares unsupported.

Please mirror the same guard before calling mBlockManager.addSequenceBatch(...).
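The requested guard amounts to a simple per-request precondition, sketched here in Python with hypothetical field names (the real code would use TLLM_CHECK_WITH_INFO in C++):

```python
# Illustrative precondition: batched recurrent-state requests must not be
# chunked, i.e. their first-chunk input length must equal the prompt length.
def check_recurrent_state_chunking(requests):
    for req in requests:
        if req.get("recurrent_state") and req["input_length"] != req["prompt_len"]:
            raise ValueError(
                f"request {req['id']}: batched recurrent-state requests must "
                f"have input_length == prompt_len")
```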


Comment on lines +3408 to +3411
std::vector<SizeType32> numAllocTotalBlocksPre(n);
std::vector<SizeType32> numAllocNewBlocksPre(n);
std::vector<SizeType32> numReusedBlocksPre(n);
std::vector<SizeType32> numMissedBlocksPre(n);

⚠️ Potential issue | 🟠 Major

These per-request block stats are accumulated across the whole batch.

All four baselines are captured before any request is processed, then each request subtracts from the final global totals. In a multi-request batch, every request inherits work done for earlier requests too, so updateAlloc*PerRequest() / updateReused*PerRequest() / updateMissed*PerRequest() are overstated.

Record per-request deltas during Phase 2, or snapshot/update around each request’s actual completion instead of around the whole batch.

Also applies to: 3462-3466
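The difference between the two accounting schemes is easy to demonstrate with a toy counter (illustrative code, not the actual updateAlloc*PerRequest logic):

```python
# global_counter is a one-element list standing in for the manager's running
# allocation total; process() performs a request's allocations.

def per_request_stats(global_counter, requests, process):
    """Snapshot the total around each request, so each request is charged
    only for its own allocations."""
    stats = []
    for req in requests:
        before = global_counter[0]
        process(req, global_counter)
        stats.append(global_counter[0] - before)
    return stats

def batch_baseline_stats(global_counter, requests, process):
    """The flagged pattern: one baseline for the whole batch, so every
    request inherits all work done before the final readout."""
    baseline = global_counter[0]
    for req in requests:
        process(req, global_counter)
    return [global_counter[0] - baseline for _ in requests]
```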


Comment on lines +396 to +413
.def(
    "add_sequence_batch",
    [](tbk::BaseKVCacheManager& self, nb::list requestInfosList, nb::list llmRequestsList)
    {
        std::vector<std::tuple<tb::LlmRequest::RequestIdType, SizeType32, SizeType32>> requestInfos;
        std::vector<std::reference_wrapper<tb::LlmRequest>> llmRequests;
        requestInfos.reserve(nb::len(requestInfosList));
        llmRequests.reserve(nb::len(llmRequestsList));
        for (size_t i = 0; i < nb::len(requestInfosList); ++i)
        {
            auto info = nb::cast<nb::tuple>(requestInfosList[i]);
            requestInfos.emplace_back(nb::cast<tb::LlmRequest::RequestIdType>(info[0]),
                nb::cast<SizeType32>(info[1]), nb::cast<SizeType32>(info[2]));
            llmRequests.push_back(std::ref(nb::cast<tb::LlmRequest&>(llmRequestsList[i])));
        }
        self.addSequenceBatch(requestInfos, llmRequests);
    },
    nb::arg("request_infos"), nb::arg("llm_requests"), nb::call_guard<nb::gil_scoped_release>())

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🌐 Web query:

In nanobind, does nb::call_guard<nb::gil_scoped_release>() release the GIL before entering the bound lambda body, and is it safe to call nb::len, index an nb::list, or nb::cast Python objects while that guard is active?

💡 Result:

Yes. nb::call_guard<T>() wraps the bound callable such that the guard is constructed before the target function/lambda is invoked and destroyed after it returns—nanobind’s docs show it as essentially:

T scope_guard;
return foo(args...);

So nb::call_guard<nb::gil_scoped_release>() releases the GIL for the duration of the bound function/lambda body. [1], [2]

While that guard is active, it is not safe to use Python-object operations like nb::len(...), indexing a nb::list (i.e., operator[] / __getitem__-style access), or nb::cast(...), because these are defined as doing the corresponding Python operations (and/or performing casts that touch Python refcounts / APIs), which require the GIL. Nanobind will even abort in some cases if refcounts are touched without the GIL. [3], [4]

If you need to do those operations inside such a function, reacquire the GIL in a smaller scope (e.g., nb::gil_scoped_acquire) around the Python interaction, or don’t release the GIL for that portion of the code. [2]

Sources: [1] nanobind “Call guards” docs, [2] nanobind “GIL Management” docs, [3] nanobind API docs for len() / list access semantics, [4] example nanobind crash message when changing refcounts without the GIL.


Don't drop the GIL before marshalling the Python lists.

nb::call_guard<nb::gil_scoped_release>() releases the GIL for the entire lambda body, making all calls to nb::len(...), list indexing, and nb::cast(...) unsafe. These operations require the GIL and can cause crashes or data races under concurrent Python activity. Keep the GIL while converting the Python inputs, and release it only around self.addSequenceBatch(...).

🐛 Suggested fix
         .def(
             "add_sequence_batch",
             [](tbk::BaseKVCacheManager& self, nb::list requestInfosList, nb::list llmRequestsList)
             {
+                if (nb::len(requestInfosList) != nb::len(llmRequestsList))
+                {
+                    throw nb::value_error("request_infos and llm_requests must have the same length");
+                }
+
                 std::vector<std::tuple<tb::LlmRequest::RequestIdType, SizeType32, SizeType32>> requestInfos;
                 std::vector<std::reference_wrapper<tb::LlmRequest>> llmRequests;
                 requestInfos.reserve(nb::len(requestInfosList));
                 llmRequests.reserve(nb::len(llmRequestsList));
                 for (size_t i = 0; i < nb::len(requestInfosList); ++i)
@@
                     requestInfos.emplace_back(nb::cast<tb::LlmRequest::RequestIdType>(info[0]),
                         nb::cast<SizeType32>(info[1]), nb::cast<SizeType32>(info[2]));
                     llmRequests.push_back(std::ref(nb::cast<tb::LlmRequest&>(llmRequestsList[i])));
                 }
-                self.addSequenceBatch(requestInfos, llmRequests);
+                nb::gil_scoped_release release;
+                self.addSequenceBatch(requestInfos, llmRequests);
             },
-            nb::arg("request_infos"), nb::arg("llm_requests"), nb::call_guard<nb::gil_scoped_release>())
+            nb::arg("request_infos"), nb::arg("llm_requests"))

Comment on lines 644 to +678
if req.is_first_context_chunk and self._kv_connector_should_add_sequence(
req):
self.impl.add_sequence(req.py_request_id,
req.prompt_len, req_beam_width,
req)
for _ in range(self.num_extra_kv_tokens):
self.impl.add_token(req.py_request_id)
for _ in range(get_draft_token_length(req)):
self.impl.add_token(req.py_request_id)

if self.kv_connector_manager is not None:
block_ids = self.get_cache_indices(req)
self.kv_connector_manager.update_state_after_alloc(
req, block_ids)
if self.enable_block_reuse:
batch_request_infos.append(
(req.py_request_id, req.prompt_len,
req_beam_width))
batch_llm_requests.append(req)
batch_ctx_requests.append(req)
else:
self.impl.add_sequence(req.py_request_id,
req.prompt_len,
req_beam_width, req)
for _ in range(self.num_extra_kv_tokens):
self.impl.add_token(req.py_request_id)
for _ in range(get_draft_token_length(req)):
self.impl.add_token(req.py_request_id)

if self.kv_connector_manager is not None:
block_ids = self.get_cache_indices(req)
self.kv_connector_manager.update_state_after_alloc(
req, block_ids)

if batch_request_infos:
self.impl.add_sequence_batch(batch_request_infos,
batch_llm_requests)
for req in batch_ctx_requests:
for _ in range(self.num_extra_kv_tokens):
self.impl.add_token(req.py_request_id)
for _ in range(get_draft_token_length(req)):
self.impl.add_token(req.py_request_id)

if self.kv_connector_manager is not None:
block_ids = self.get_cache_indices(req)
self.kv_connector_manager.update_state_after_alloc(
req, block_ids)

⚠️ Potential issue | 🟠 Major

Keep VSWA on the existing per-request path.

The new C++ API is documented as single-window only, but this branch now calls add_sequence_batch(...) for every block-reuse request. With VSWA enabled, that turns a previously supported path into an unsupported call. Gate batching with not self.is_vswa and let VSWA fall back to add_sequence(...).

💡 Suggested fix
-                        if self.enable_block_reuse:
+                        if self.enable_block_reuse and not self.is_vswa:
                             batch_request_infos.append(
                                 (req.py_request_id, req.prompt_len,
                                  req_beam_width))
                             batch_llm_requests.append(req)
                             batch_ctx_requests.append(req)

@tensorrt-cicd
Collaborator

PR_Github #42543 [ run ] completed with state FAILURE. Commit: 2344609
/LLM/main/L0_MergeRequest_PR pipeline #33280 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

