Commits
93 commits
bdb3791
[None][feat] Wire KVCacheBlock to UnifiedBlockTree, replacing mPrevBl…
SimengLiu-nv Mar 4, 2026
bd4810a
Address comments.
SimengLiu-nv Mar 5, 2026
27574b9
block allocation and reusing works for linear attention
VALLIS-NERIA Jan 28, 2026
3543bbe
copy states during context shifts
VALLIS-NERIA Jan 30, 2026
36aa474
fix corner cases
VALLIS-NERIA Feb 4, 2026
cd1a67b
temp stage: accuracy w/o reuse ok
VALLIS-NERIA Mar 2, 2026
94d4312
temp stage: accuracy with reuse ok
VALLIS-NERIA Mar 2, 2026
d885842
fix merge conflicts
VALLIS-NERIA Mar 9, 2026
603d822
Merge remote-tracking branch 'origin/main' into pr-11919
VALLIS-NERIA Mar 9, 2026
b398561
temporary stage
VALLIS-NERIA Mar 13, 2026
df7284a
fix multiple issues
VALLIS-NERIA Mar 14, 2026
ce9674a
Merge remote-tracking branch 'origin/main' into user/xiweny/linear_re…
VALLIS-NERIA Mar 14, 2026
cab2412
use pre calculated buffers
VALLIS-NERIA Mar 14, 2026
a1889b8
Merge remote-tracking branch 'origin/main' into user/xiweny/linear_re…
VALLIS-NERIA Mar 17, 2026
22e7fd2
scheduler support
VALLIS-NERIA Mar 18, 2026
6475692
FIFO placeholder management
VALLIS-NERIA Mar 18, 2026
3312fa9
remove debug prints in module/op level
VALLIS-NERIA Mar 18, 2026
9b73cbf
change memory layout to layer first
VALLIS-NERIA Mar 18, 2026
efbb815
fix scheduler
VALLIS-NERIA Mar 18, 2026
aa15395
auto choose mamba cache manager impl
VALLIS-NERIA Mar 19, 2026
5bfda48
format code
VALLIS-NERIA Mar 19, 2026
f9e2ad0
fix unhandled kFORCE_CHUNK enum in switch statement
VALLIS-NERIA Mar 20, 2026
1810dba
fix config of current implementation
VALLIS-NERIA Mar 20, 2026
cf50425
merge upstream main and resolve conflicts
VALLIS-NERIA Mar 20, 2026
4dd57bf
fix missing is_nemotron_hybrid/is_qwen3_hybrid imports
VALLIS-NERIA Mar 20, 2026
b4e54e7
remove some hacks
VALLIS-NERIA Mar 20, 2026
ee0b690
[Agent fix] restore block reuse defaults and fix AutoDeploy mamba_lay…
VALLIS-NERIA Mar 20, 2026
782c46f
Merge remote-tracking branch 'origin/main' into user/xiweny/linear_re…
VALLIS-NERIA Mar 21, 2026
75b9438
Merge remote-tracking branch 'origin/main' into user/xiweny/linear_re…
VALLIS-NERIA Mar 22, 2026
c27c351
revert to use old mambacachemanager as default
VALLIS-NERIA Mar 22, 2026
850bd66
[Agent fix] Remove debug prints, commented debug code, and tensor dum…
VALLIS-NERIA Mar 22, 2026
b0921fb
fix not mine unit tests
VALLIS-NERIA Mar 22, 2026
7f03f58
temporary disable my unit tests to run CI
VALLIS-NERIA Mar 22, 2026
d020bf6
Revert "revert to use old mambacachemanager as default"
VALLIS-NERIA Mar 22, 2026
98da518
only auto-deploy uses old mambacachemanager & fix beam search
VALLIS-NERIA Mar 22, 2026
eb0044d
use ceil div for head split
VALLIS-NERIA Mar 23, 2026
325e454
get rid of model_config
VALLIS-NERIA Mar 23, 2026
27ef0bf
[TRTLLM-10061][fix] Use ceil_div for head/size calculations in model_…
VALLIS-NERIA Mar 23, 2026
8620daa
[TRTLLM-10061][feat] Add stride support for conv1d and fused_sigmoid_…
VALLIS-NERIA Mar 23, 2026
398495f
fix memory usage and model_config check
VALLIS-NERIA Mar 23, 2026
41f1b77
Remove index bounds checking in h0_source store
VALLIS-NERIA Mar 23, 2026
1281d40
Merge remote-tracking branch 'fork/user/xiweny/ceil_div_model_config'…
VALLIS-NERIA Mar 23, 2026
96c62e7
Merge remote-tracking branch 'fork/user/xiweny/stride_support' into u…
VALLIS-NERIA Mar 23, 2026
1c83f1e
refine evictionpolicy
VALLIS-NERIA Mar 23, 2026
7dacd9e
refine mamba cache manager
VALLIS-NERIA Mar 24, 2026
ab3bc32
Merge remote-tracking branch 'origin/main' into user/xiweny/linear_re…
VALLIS-NERIA Mar 24, 2026
b31dd85
clean up unnecessary chagnes
VALLIS-NERIA Mar 24, 2026
81eb415
fix
VALLIS-NERIA Mar 24, 2026
12d8dda
add tests for scheduler
VALLIS-NERIA Mar 24, 2026
5d3a46e
[TRTLLM-10061][feat] Add FORCE_CHUNK context chunking policy
VALLIS-NERIA Mar 24, 2026
730d6ea
Merge remote-tracking branch 'origin' into user/xiweny/force_chunk_po…
VALLIS-NERIA Mar 24, 2026
6ed49f3
add tests for scheduler
VALLIS-NERIA Mar 24, 2026
807e9d3
improve comments
VALLIS-NERIA Mar 24, 2026
45f0fa5
fix
VALLIS-NERIA Mar 24, 2026
5f54054
Merge branch 'main' into user/xiweny/force_chunk_policy
VALLIS-NERIA Mar 24, 2026
a8dea92
fix kvcache manager ut
VALLIS-NERIA Mar 24, 2026
4717171
Merge remote-tracking branch 'origin/main' into user/xiweny/linear_re…
VALLIS-NERIA Mar 24, 2026
43469ec
Merge branch 'main' into user/xiweny/force_chunk_policy
VALLIS-NERIA Mar 24, 2026
422b80f
Merge remote-tracking branch 'origin/main' into user/xiweny/linear_re…
VALLIS-NERIA Mar 25, 2026
bda1763
Merge branch 'main' into user/xiweny/force_chunk_policy
VALLIS-NERIA Mar 25, 2026
86b437a
clean c++ code
VALLIS-NERIA Mar 25, 2026
6f0de1b
Merge remote-tracking branch 'fork/user/xiweny/force_chunk_policy' in…
VALLIS-NERIA Mar 25, 2026
414c4d2
Merge remote-tracking branch 'origin/main' into user/xiweny/linear_re…
VALLIS-NERIA Mar 25, 2026
bf5599d
Merge remote-tracking branch 'origin/main' into user/xiweny/linear_re…
VALLIS-NERIA Apr 3, 2026
c1129e4
remove warning at exit
VALLIS-NERIA Apr 3, 2026
82b7049
rename
VALLIS-NERIA Apr 3, 2026
42cd76b
[TRTLLM-10061][fix] Address review items for linear attention hybrid …
VALLIS-NERIA Apr 3, 2026
d616e76
fix style
VALLIS-NERIA Apr 3, 2026
d8dbaf0
Merge remote-tracking branch 'origin/main' into user/xiweny/linear_re…
VALLIS-NERIA Apr 3, 2026
162a5eb
[Agent fix] Add missing triton import in modeling_qwen3_next.py
VALLIS-NERIA Apr 3, 2026
18d8650
Merge remote-tracking branch 'origin/main' into user/xiweny/linear_re…
VALLIS-NERIA Apr 4, 2026
7a1b3eb
[Agent fix] Remove extra blank line in cuda_graph_runner.py
VALLIS-NERIA Apr 4, 2026
9319bf7
refine tests
VALLIS-NERIA Apr 4, 2026
fe01292
[Agent fix] Add missing imports for ruff-legacy compliance (F821/F811)
VALLIS-NERIA Apr 4, 2026
ea36783
[Agent fix] Remove duplicate ids kwarg from parametrize_with_ids call
VALLIS-NERIA Apr 5, 2026
f40cc41
[Agent fix] Update test list entries for parametrized test_fp8
VALLIS-NERIA Apr 5, 2026
67a4a55
[Agent fix] Fix CppMambaHybridCacheManager.get_state_indices signature
VALLIS-NERIA Apr 5, 2026
cc52b20
[Agent fix] Use MixedMambaHybridCacheManager in test_mamba_transfer f…
VALLIS-NERIA Apr 5, 2026
9247acb
Merge remote-tracking branch 'origin/main' into user/xiweny/linear_re…
VALLIS-NERIA Apr 5, 2026
e56e95e
[Agent fix] Add missing spec_metadata parameter to Qwen3NextGatedDelt…
VALLIS-NERIA Apr 6, 2026
605c6af
[Agent fix] Remove duplicate Qwen3NextGatedDeltaNet from modeling_qwe…
VALLIS-NERIA Apr 6, 2026
ee653a6
fix silly AI, unify naming and test
VALLIS-NERIA Apr 6, 2026
d3c7589
WAR block save issue
VALLIS-NERIA Apr 6, 2026
e8e8520
address comments
VALLIS-NERIA Apr 6, 2026
079f8bf
fix attention DP sharding
VALLIS-NERIA Apr 7, 2026
0115a80
Merge branch 'main' into user/xiweny/linear_reuse_new
xinhe-nv Apr 7, 2026
67a3c57
Merge remote-tracking branch 'fork/user/xiweny/linear_reuse_new' into…
VALLIS-NERIA Apr 7, 2026
ad051e9
address commentes
VALLIS-NERIA Apr 7, 2026
9a34a49
fix the placeholder issue
VALLIS-NERIA Apr 7, 2026
daa6320
Update l0_gb200_multi_gpus.yml
VALLIS-NERIA Apr 8, 2026
866dfcc
address comments
VALLIS-NERIA Apr 8, 2026
3915d7d
address comments
VALLIS-NERIA Apr 9, 2026
95923ce
[None][feat] Enable mamba/linear attention cache reuse in scheduler (…
VALLIS-NERIA Apr 9, 2026
14 changes: 10 additions & 4 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -189,6 +189,10 @@ struct LinearAttentionMetadata
// take a snapshot every `blockAlignment` blocks.
auto perBlockBytes = allRecurrentStatesBytes * numLayers;
auto numDynamicBlocks = (memoryBudget / perBlockBytes);
TLLM_LOG_INFO(
"Calculated max memory blocks for linear cache with recurrent states: memoryBudget=%zu, "
"perBlockBytes=%zu, numDynamicBlocks=%d",
memoryBudget, perBlockBytes, numDynamicBlocks);
return static_cast<SizeType32>(numDynamicBlocks);
}
TLLM_THROW("Unknown linear cache type");
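The block-count arithmetic in this hunk is straightforward integer division; a minimal Python sketch (function name hypothetical, mirroring the C++ variables above):

```python
def num_dynamic_blocks(memory_budget: int, all_recurrent_states_bytes: int,
                       num_layers: int) -> int:
    """Mirror of the C++ calculation above: one snapshot block stores the
    recurrent states of every layer, so the block count is the memory
    budget divided by the per-block footprint (floor division)."""
    per_block_bytes = all_recurrent_states_bytes * num_layers
    return memory_budget // per_block_bytes
```

With a 1024-byte budget, 16 bytes of recurrent state per layer, and 4 layers, this yields 16 dynamic blocks.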
@@ -356,10 +360,7 @@ class KVCacheBlock : public std::enable_shared_from_this<KVCacheBlock>
static BlockPtr createPlaceholder(IdType blockId, SizeType32 windowSize);

void detachDescendantsFromLookupTree();
//! \brief Detach all placeholder blocks in the previous-block chain from the lookup tree.
//! \details Walks upward via getPrevBlock() and calls detachFromLookupNode() on each
//! block that is a placeholder. Stops at the root (kCachedBlocksRootId).
void detachPreviousPlaceholdersFromLookupTree() const;

void freeBlockAndAllDescendants();

//! \brief Find block matching blockKey. If allowPartial is true, the returned block may match only a prefix of
@@ -520,6 +521,11 @@ class GenerationRequest
return mCacheBlockIds.at(windowSize);
}

[[nodiscard]] std::vector<std::vector<SizeType32>>& getCacheBlockIds(SizeType32 windowSize)
{
return mCacheBlockIds.at(windowSize);
}

[[nodiscard]] runtime::ITensor& getCacheBlockIndices(SizeType32 windowSize)
{
return *(mCacheBlockIndices.at(windowSize));
69 changes: 42 additions & 27 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@@ -494,34 +494,9 @@ void KVCacheBlock::detachDescendantsFromLookupTree()
}
}

void KVCacheBlock::detachPreviousPlaceholdersFromLookupTree() const
{
BlockPtr current = getPrevBlock();
while (current != nullptr && current->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
{
if (!current->isPlaceholder())
{
return;
}
auto siblings = current->getNextBlocks();
for (auto const& [key, block] : siblings)
{
if (!block->isPlaceholder() && block.get() != this)
{
return;
}
}
BlockPtr prev = current->getPrevBlock();
current->detachFromLookupNode();
current->setPrevBlockInSeq(nullptr);
current = prev;
}
}

void KVCacheBlock::freeBlockAndAllDescendants()
{
detachDescendantsFromLookupTree();
detachPreviousPlaceholdersFromLookupTree();
detachFromLookupNode();
}

@@ -861,7 +836,9 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
mLogPrefix.c_str(), numPlaceholderBlocks, KVCacheBlock::kCachedBlocksRootId - 1 - numPlaceholderBlocks,
KVCacheBlock::kCachedBlocksRootId - 2);
TLLM_CHECK_WITH_INFO(isRecurrentState(),
"numPlaceholderBlocks > 0 is only supported for recurrent-state (kRecurrentStates) managers");
"numPlaceholderBlocks > 0 is only supported for recurrent-state (kRecurrentStates) managers, but this "
"manager has windowSize=%d and isSWA=%d",
windowSize, isSWA);
mAllPlaceholderBlocksById.resize(numPlaceholderBlocks + 2, nullptr);
for (SizeType32 i = 0; i < numPlaceholderBlocks; ++i)
{
@@ -1859,7 +1836,16 @@ bool WindowBlockManager::tryAllocatePlaceholderForLinearAttention(GenerationRequ
for (auto beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
{
auto lastBlockId = lastBlockIds[beamIdx];
TLLM_CHECK(lastBlockId >= 0);
if (lastBlockId < 0)
{
TLLM_THROW("%s::tryAllocatePlaceholderForLinearAttention - invalid last block id %d for beam %d; "
"the sequence has no allocated block to swap with the placeholder",
mLogPrefix.c_str(), lastBlockId, beamIdx);
}
}
TLLM_LOG_DEBUG("%s::allocateBlock - Swapping placeholder with last block %d for beam %d", mLogPrefix.c_str(),
lastBlockId, beamIdx);
auto lastBlock = getBlockById(lastBlockId);
@@ -2198,6 +2184,35 @@ void BlockManager::releaseLastBlock(GenerationRequest& sequence, SizeType32 wind

void WindowBlockManager::releaseLastBlock(GenerationRequest& sequence)
{
if (isRecurrentState())
{
// In recurrent state, the last block always contains the current state and must not be released.
// The only caller of releaseLastBlock is speculative-decoding rewind, which happens only during the
// decode phase, so we pop off the second-to-last block instead; it is expected to be a placeholder.
auto const requestId = sequence.getRequestId();
auto& allocatedBlocks = mAllocatedBlocksPerSeq.at(requestId);
TLLM_CHECK(allocatedBlocks.size() >= 2);
auto it = allocatedBlocks.rbegin();
auto& lastBlock = *it;
auto& secondLastBlock = *(++it);
TLLM_CHECK(secondLastBlock->isPlaceholder());
// Decrease ref count of the second last block (placeholder)
secondLastBlock->decRefCount();
if (!secondLastBlock->hasRefs())
{
mEvictionPolicy->releaseBlock(secondLastBlock, true);
}
// Remove the second last block from allocated blocks
allocatedBlocks.erase((++it).base());
// Remove stored block ids in sequence
auto beamWidth = sequence.getBeamWidth();
for (auto beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
{
sequence.getCacheBlockIds(mWindowSize)[beamIdx].erase(
sequence.getCacheBlockIds(mWindowSize)[beamIdx].end() - 2);
}
return;
}
auto const requestId = sequence.getRequestId();
auto& allocatedBlocks = mAllocatedBlocksPerSeq.at(requestId);
auto it = allocatedBlocks.rbegin();
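The reverse-iterator erase above (`allocatedBlocks.erase((++it).base())`) removes the second-to-last element: for a reverse iterator sitting on index `n-3`, `.base()` yields the forward iterator at index `n-2`. The same bookkeeping can be sketched in Python (a hypothetical simplification that ignores ref counting and the eviction policy):

```python
def release_second_last_block(allocated_blocks: list):
    """For recurrent-state managers the last block holds the live state,
    so rewinding removes the second-to-last (placeholder) block instead
    of the last one. Returns the removed block."""
    assert len(allocated_blocks) >= 2, "need at least a placeholder plus the state block"
    # list.pop(-2) is the Python analogue of erase((++it).base()) on the
    # reverse iterator positioned two elements from the end.
    return allocated_blocks.pop(-2)
```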
Empty diff: cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp — file mode changed 100755 → 100644, no content changes.
5 changes: 2 additions & 3 deletions tensorrt_llm/_torch/attention_backend/interface.py
@@ -22,7 +22,7 @@

from ..memory_buffer_utils import Buffers
from ..metadata import KVCacheParams
from ..pyexecutor.mamba_cache_manager import MambaCacheManager
from ..pyexecutor.mamba_cache_manager import BaseMambaCacheManager
from ..pyexecutor.resource_manager import KVCacheManager, KVCacheManagerV2
from ..utils import get_model_extra_attrs

@@ -305,8 +305,7 @@ def _prepare_mamba_metadata(self):
return

if self.mamba_metadata is None:
if (self.kv_cache_manager is not None
and isinstance(self.kv_cache_manager, MambaCacheManager)):
if isinstance(self.kv_cache_manager, BaseMambaCacheManager):
from ..modules.mamba.mamba2_metadata import Mamba2Metadata
self.mamba_metadata = Mamba2Metadata(self.max_num_requests,
self.mamba_chunk_size)
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -51,7 +51,7 @@
from ...._utils import get_free_port, mpi_rank, mpi_world_size
from ....mapping import Mapping
from ...distributed import Distributed
from ...pyexecutor.mamba_cache_manager import MambaHybridCacheManager
from ...pyexecutor.mamba_cache_manager import BaseMambaCacheManager
from ...pyexecutor.model_engine import ModelEngine, PyTorchModelEngine
from ...pyexecutor.py_executor import PyExecutor
from ...pyexecutor.resource_manager import (
@@ -292,7 +292,7 @@ def _generate_dummy_request(
)

# check if it's a hybrid kv-cache manager
is_hybrid_cache = isinstance(kv_cache_manager, MambaHybridCacheManager)
is_hybrid_cache = isinstance(kv_cache_manager, BaseMambaCacheManager)

# check if we have a free page and free state available
if not kv_cache_manager.get_num_free_blocks():
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/shim/interface.py
@@ -11,7 +11,7 @@
from tensorrt_llm.mapping import Mapping

from ...._utils import torch_dtype_to_binding
from ...pyexecutor.mamba_cache_manager import MambaHybridCacheManager
from ...pyexecutor.mamba_cache_manager import MambaHybridCacheManager, MixedMambaHybridCacheManager
from ...pyexecutor.resource_manager import KVCacheManager
from ..custom_ops.attention_interface import (
CausalConvResourceHandler,
@@ -511,7 +511,7 @@ def _create_and_assign_state_views(
num_managed_mamba_layers = mamba_params["mamba_num_layers"]

# Create the hybrid cache manager
manager = MambaHybridCacheManager(
manager = MixedMambaHybridCacheManager(
**mamba_params,
**kv_cache_kwargs,
)
56 changes: 44 additions & 12 deletions tensorrt_llm/_torch/model_config.py
@@ -4,27 +4,32 @@
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Generic, List, Optional, TypeVar
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar

import filelock
import torch
import transformers
from transformers.utils import HF_MODULES_CACHE

from tensorrt_llm import logger
from tensorrt_llm._torch.pyexecutor.config_utils import (
get_qwen3_hybrid_num_attention_layers, is_nemotron_hybrid, is_qwen3_hybrid,
load_pretrained_config)
get_qwen3_hybrid_num_attention_layers, is_hybrid_linear, is_nemotron_hybrid,
is_qwen3_hybrid, load_pretrained_config)
from tensorrt_llm._utils import get_sm_version, torch_dtype_to_binding
from tensorrt_llm.bindings import LayerType as LayerTypeCpp
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.llmapi.llm_args import (DeepSeekSparseAttentionConfig,
MoeLoadBalancerConfig)
KvCacheConfig, MoeLoadBalancerConfig)
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo

if TYPE_CHECKING:
from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp
from tensorrt_llm.llmapi.llm_args import (DecodingBaseConfig, LoraConfig,
SparseAttentionConfig,
SpeculativeConfig)

TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig)


@@ -636,9 +641,12 @@ def _recursive_update_config(config: transformers.PretrainedConfig,
model_config._frozen = True
return model_config

def get_bindings_model_config(self,
tokens_per_block: Optional[int] = None
) -> "ModelConfigCpp":
def get_bindings_model_config(
self,
tokens_per_block: Optional[int] = None,
kv_cache_config: Optional[KvCacheConfig] = None,
spec_config: Optional['SpeculativeConfig'] = None,
) -> "ModelConfigCpp":
"""
This method is used to construct the bindings config for the model.
Currently it adheres to gptJsonConfig.cpp::createModelConfig, which assumes
@@ -667,7 +675,8 @@ def ceil_div(a, b):

hidden_size = ceil_div(self.pretrained_config.hidden_size, attn_tp_size)
num_layers = self.pretrained_config.num_hidden_layers
num_attention_layers = self.get_num_attention_layers()
num_attention_layers = self.get_num_attention_layers(
kv_cache_config, spec_config)
if (self.spec_config is not None
and self.spec_config.spec_dec_mode.is_mtp_one_model()):
num_layers += self.spec_config.num_nextn_predict_layers
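`ceil_div` here rounds the per-rank split of hidden size and heads upward so that no tensor-parallel rank comes up short when the totals are not evenly divisible; a one-line sketch of the helper the hunk calls:

```python
def ceil_div(a: int, b: int) -> int:
    # Round-up integer division, as used above for per-rank head/size splits:
    # e.g. 5 KV heads over 4 TP ranks must allocate 2 heads per rank, not 1.
    return (a + b - 1) // b
```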
@@ -693,6 +702,7 @@ def ceil_div(a, b):

num_key_value_heads = getattr(self.pretrained_config,
"num_key_value_heads", num_heads)

if isinstance(num_key_value_heads, (list, tuple)):
# Per-layer KV heads (e.g., Nemotron-NAS, variable GQA models)
num_kv_heads_per_layer = [
@@ -796,10 +806,32 @@ def get_layer_types(self) -> Optional[List[LayerTypeCpp]]:
else:
return None

def get_num_attention_layers(self):
if is_nemotron_hybrid(self.pretrained_config):
def get_num_attention_layers(
self,
kv_cache_config: Optional[KvCacheConfig] = None,
spec_config: Optional['SpeculativeConfig'] = None):
"""Return the number of layers that need KV cache blocks.

For hybrid models using the MixedMambaHybridCacheManager path
(TRTLLM_USE_CPP_MAMBA=1 for disagg), only attention layers need KV
cache blocks, so we return the attention-only count.

For the default CppMambaHybridCacheManager path (including speculative
decoding), both attention and mamba layers are managed in the unified
KV cache pool, so we return num_hidden_layers (all layers).
"""
use_disagg = os.environ.get('TRTLLM_USE_CPP_MAMBA', '0') == '1'
use_reuse = kv_cache_config is not None and kv_cache_config.enable_block_reuse

use_v1_mamba_manager = use_disagg
if is_hybrid_linear(
self.pretrained_config) and use_v1_mamba_manager and use_reuse:
logger.warning(
"Block reuse does not work with MTP or disagg for hybrid linear models"
)
if is_nemotron_hybrid(self.pretrained_config) and use_v1_mamba_manager:
return self.pretrained_config.hybrid_override_pattern.count("*")
elif is_qwen3_hybrid(self.pretrained_config):
elif is_qwen3_hybrid(self.pretrained_config) and use_v1_mamba_manager:
return get_qwen3_hybrid_num_attention_layers(self.pretrained_config)
else:
return self.pretrained_config.num_hidden_layers
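For the Nemotron-style branch above, the attention-only count reduces to a character count over `hybrid_override_pattern`, where `*` marks an attention layer and other characters (e.g. `M` for mamba, `-` for MLP) mark layers that need no KV cache blocks on the mixed-manager path; a minimal sketch (the example pattern string is hypothetical):

```python
def count_attention_layers(hybrid_override_pattern: str) -> int:
    # '*' marks an attention layer in the hybrid override pattern;
    # mamba/MLP layers are denoted by other characters and are managed
    # separately, so only '*' layers need KV cache blocks here.
    return hybrid_override_pattern.count("*")
```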
7 changes: 6 additions & 1 deletion tensorrt_llm/_torch/modules/mamba/gdn_mixer.py
@@ -763,9 +763,14 @@ def forward(

state_indices_p, state_indices_d = torch.split(state_indices, batch_split_size)
if num_prefills > 0:
ssm_states[state_indices_p] = torch.zeros(
# PyExecutor guarantees prefill requests are placed before decode requests
has_initial_states_p = has_initial_states[:num_prefills]
ssm_states[state_indices_p[~has_initial_states_p]] = torch.zeros(
(), dtype=ssm_states.dtype, device=ssm_states.device
)
conv_states[state_indices_p[~has_initial_states_p]] = torch.zeros(
(), dtype=conv_states.dtype, device=conv_states.device
)

is_target_verify = (
num_decodes > 0
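The masked reset above zeroes conv/ssm states only for prefill requests that start from scratch (`~has_initial_states_p`), while requests resuming from a reused prefix keep their cached states. The index selection can be sketched without torch (names mirror the diff; a pure-Python stand-in for the boolean-mask indexing):

```python
def states_to_reset(state_indices_prefill, has_initial_states_prefill):
    """Return the state slots of prefill requests that have no cached
    prefix and must therefore be zero-initialized; slots of requests
    resuming from a reused prefix are excluded."""
    return [
        idx
        for idx, has_init in zip(state_indices_prefill, has_initial_states_prefill)
        if not has_init  # equivalent to tensor indexing with ~has_initial_states_p
    ]
```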
3 changes: 1 addition & 2 deletions tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
@@ -383,9 +383,8 @@ def forward(
is_target_verify = attn_metadata.kv_cache_manager.is_speculative(
) and spec_metadata is not None
if is_target_verify:
# Speculative decoding only supported with Python path
assert layer_cache is not None, \
"Speculative decoding requires Python MambaCacheManager"
"Speculative decoding requires mamba_layer_cache() support"
# TODO: support dynamic speculation, will add current_draft_len later [TRTLLM-10319]
draft_token_num = spec_metadata.max_draft_len + 1
intermediate_conv_states = layer_cache.intermediate_conv_window