diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index fc5bf54c323..3d605d4a6f5 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -189,6 +189,10 @@ struct LinearAttentionMetadata // take a snapshot every `blockAlignment` blocks. auto perBlockBytes = allRecurrentStatesBytes * numLayers; auto numDynamicBlocks = (memoryBudget / perBlockBytes); + TLLM_LOG_INFO( + "Calculated max memory blocks for linear cache with recurrent states: memoryBudget=%zu, " + "perBlockBytes=%zu, numDynamicBlocks=%d", + memoryBudget, perBlockBytes, numDynamicBlocks); return static_cast(numDynamicBlocks); } TLLM_THROW("Unknown linear cache type"); @@ -356,10 +360,7 @@ class KVCacheBlock : public std::enable_shared_from_this static BlockPtr createPlaceholder(IdType blockId, SizeType32 windowSize); void detachDescendantsFromLookupTree(); - //! \brief Detach all placeholder blocks in the previous-block chain from the lookup tree. - //! \details Walks upward via getPrevBlock() and calls detachFromLookupNode() on each - //! block that is a placeholder. Stops at the root (kCachedBlocksRootId). - void detachPreviousPlaceholdersFromLookupTree() const; + void freeBlockAndAllDescendants(); //! \brief Find block matching blockKey. 
If allowPartial is true, the returned block may match only a prefix of @@ -520,6 +521,11 @@ class GenerationRequest return mCacheBlockIds.at(windowSize); } + [[nodiscard]] std::vector>& getCacheBlockIds(SizeType32 windowSize) + { + return mCacheBlockIds.at(windowSize); + } + [[nodiscard]] runtime::ITensor& getCacheBlockIndices(SizeType32 windowSize) { return *(mCacheBlockIndices.at(windowSize)); diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index addbfbf822c..46bca9e5045 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -494,34 +494,9 @@ void KVCacheBlock::detachDescendantsFromLookupTree() } } -void KVCacheBlock::detachPreviousPlaceholdersFromLookupTree() const -{ - BlockPtr current = getPrevBlock(); - while (current != nullptr && current->getBlockId() != KVCacheBlock::kCachedBlocksRootId) - { - if (!current->isPlaceholder()) - { - return; - } - auto siblings = current->getNextBlocks(); - for (auto const& [key, block] : siblings) - { - if (!block->isPlaceholder() && block.get() != this) - { - return; - } - } - BlockPtr prev = current->getPrevBlock(); - current->detachFromLookupNode(); - current->setPrevBlockInSeq(nullptr); - current = prev; - } -} - void KVCacheBlock::freeBlockAndAllDescendants() { detachDescendantsFromLookupTree(); - detachPreviousPlaceholdersFromLookupTree(); detachFromLookupNode(); } @@ -861,7 +836,9 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind mLogPrefix.c_str(), numPlaceholderBlocks, KVCacheBlock::kCachedBlocksRootId - 1 - numPlaceholderBlocks, KVCacheBlock::kCachedBlocksRootId - 2); TLLM_CHECK_WITH_INFO(isRecurrentState(), - "numPlaceholderBlocks > 0 is only supported for recurrent-state (kRecurrentStates) managers"); + "numPlaceholderBlocks > 0 is only supported for recurrent-state (kRecurrentStates) managers, but this " + "manager has windowSize=%d and isSWA=%d", 
+        windowSize, isSWA);
     mAllPlaceholderBlocksById.resize(numPlaceholderBlocks + 2, nullptr);
     for (SizeType32 i = 0; i < numPlaceholderBlocks; ++i)
     {
@@ -1859,7 +1836,16 @@ bool WindowBlockManager::tryAllocatePlaceholderForLinearAttention(GenerationRequ
     for (auto beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
     {
         auto lastBlockId = lastBlockIds[beamIdx];
-        TLLM_CHECK(lastBlockId >= 0);
+        if (lastBlockId < 0)
+        {
+            auto const& blockIds = sequence.getCacheBlockIds(mWindowSize).at(0);
+            std::string blockIdsStr;
+            for (auto id : blockIds)
+            {
+                blockIdsStr += std::to_string(id) + " ";
+            }
+            TLLM_THROW("Negative last block id %d for beam %d; sequence block ids: %s", lastBlockId, beamIdx, blockIdsStr.c_str());
+        }
         TLLM_LOG_DEBUG("%s::allocateBlock - Swapping placeholder with last block %d for beam %d",
             mLogPrefix.c_str(), lastBlockId, beamIdx);
         auto lastBlock = getBlockById(lastBlockId);
@@ -2198,6 +2184,35 @@ void BlockManager::releaseLastBlock(GenerationRequest& sequence, SizeType32 wind
 void WindowBlockManager::releaseLastBlock(GenerationRequest& sequence)
 {
+    if (isRecurrentState())
+    {
+        // In recurrent state, the last block always contains the current state and should not be released.
+        // Since the only caller of releaseLastBlock is speculative decoding rewinding, it only happens in decoding
+        // phase. We pop the second-last block instead, which is supposed to be a placeholder.
+ auto const requestId = sequence.getRequestId(); + auto& allocatedBlocks = mAllocatedBlocksPerSeq.at(requestId); + TLLM_CHECK(allocatedBlocks.size() >= 2); + auto it = allocatedBlocks.rbegin(); + auto& lastBlock = *it; + auto& secondLastBlock = *(++it); + TLLM_CHECK(secondLastBlock->isPlaceholder()); + // Decrease ref count of the second last block (placeholder) + secondLastBlock->decRefCount(); + if (!secondLastBlock->hasRefs()) + { + mEvictionPolicy->releaseBlock(secondLastBlock, true); + } + // Remove the second last block from allocated blocks + allocatedBlocks.erase((++it).base()); + // Remove stored block ids in sequence + auto beamWidth = sequence.getBeamWidth(); + for (auto beamIdx = 0; beamIdx < beamWidth; ++beamIdx) + { + sequence.getCacheBlockIds(mWindowSize)[beamIdx].erase( + sequence.getCacheBlockIds(mWindowSize)[beamIdx].end() - 2); + } + return; + } auto const requestId = sequence.getRequestId(); auto& allocatedBlocks = mAllocatedBlocksPerSeq.at(requestId); auto it = allocatedBlocks.rbegin(); diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp old mode 100755 new mode 100644 diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py index 600f655bc51..439590997b9 100644 --- a/tensorrt_llm/_torch/attention_backend/interface.py +++ b/tensorrt_llm/_torch/attention_backend/interface.py @@ -22,7 +22,7 @@ from ..memory_buffer_utils import Buffers from ..metadata import KVCacheParams -from ..pyexecutor.mamba_cache_manager import MambaCacheManager +from ..pyexecutor.mamba_cache_manager import BaseMambaCacheManager from ..pyexecutor.resource_manager import KVCacheManager, KVCacheManagerV2 from ..utils import get_model_extra_attrs @@ -305,8 +305,7 @@ def _prepare_mamba_metadata(self): return if self.mamba_metadata is None: - if (self.kv_cache_manager is not None - and isinstance(self.kv_cache_manager, MambaCacheManager)): + 
if isinstance(self.kv_cache_manager, BaseMambaCacheManager): from ..modules.mamba.mamba2_metadata import Mamba2Metadata self.mamba_metadata = Mamba2Metadata(self.max_num_requests, self.mamba_chunk_size) diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index ea336dab0f4..6533baa629f 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -51,7 +51,7 @@ from ...._utils import get_free_port, mpi_rank, mpi_world_size from ....mapping import Mapping from ...distributed import Distributed -from ...pyexecutor.mamba_cache_manager import MambaHybridCacheManager +from ...pyexecutor.mamba_cache_manager import BaseMambaCacheManager from ...pyexecutor.model_engine import ModelEngine, PyTorchModelEngine from ...pyexecutor.py_executor import PyExecutor from ...pyexecutor.resource_manager import ( @@ -292,7 +292,7 @@ def _generate_dummy_request( ) # check if it's a hybrid kv-cache manager - is_hybrid_cache = isinstance(kv_cache_manager, MambaHybridCacheManager) + is_hybrid_cache = isinstance(kv_cache_manager, BaseMambaCacheManager) # check if we have a free page and free state available if not kv_cache_manager.get_num_free_blocks(): diff --git a/tensorrt_llm/_torch/auto_deploy/shim/interface.py b/tensorrt_llm/_torch/auto_deploy/shim/interface.py index 8c60753aaab..f073d73cd38 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/interface.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/interface.py @@ -11,7 +11,7 @@ from tensorrt_llm.mapping import Mapping from ...._utils import torch_dtype_to_binding -from ...pyexecutor.mamba_cache_manager import MambaHybridCacheManager +from ...pyexecutor.mamba_cache_manager import MambaHybridCacheManager, MixedMambaHybridCacheManager from ...pyexecutor.resource_manager import KVCacheManager from ..custom_ops.attention_interface import ( CausalConvResourceHandler, @@ -511,7 +511,7 @@ def 
_create_and_assign_state_views( num_managed_mamba_layers = mamba_params["mamba_num_layers"] # Create the hybrid cache manager - manager = MambaHybridCacheManager( + manager = MixedMambaHybridCacheManager( **mamba_params, **kv_cache_kwargs, ) diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index f9988fc9442..b49b9617dce 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -4,27 +4,32 @@ import tempfile from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, Generic, List, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar import filelock import torch import transformers from transformers.utils import HF_MODULES_CACHE -from tensorrt_llm import logger from tensorrt_llm._torch.pyexecutor.config_utils import ( - get_qwen3_hybrid_num_attention_layers, is_nemotron_hybrid, is_qwen3_hybrid, - load_pretrained_config) + get_qwen3_hybrid_num_attention_layers, is_hybrid_linear, is_nemotron_hybrid, + is_qwen3_hybrid, load_pretrained_config) from tensorrt_llm._utils import get_sm_version, torch_dtype_to_binding from tensorrt_llm.bindings import LayerType as LayerTypeCpp from tensorrt_llm.functional import AllReduceStrategy from tensorrt_llm.llmapi.llm_args import (DeepSeekSparseAttentionConfig, - MoeLoadBalancerConfig) + KvCacheConfig, MoeLoadBalancerConfig) from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig from tensorrt_llm.quantization.mode import QuantAlgo +if TYPE_CHECKING: + from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp + from tensorrt_llm.llmapi.llm_args import (DecodingBaseConfig, LoraConfig, + SparseAttentionConfig, + SpeculativeConfig) + TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig) @@ -636,9 +641,12 @@ def _recursive_update_config(config: transformers.PretrainedConfig, 
model_config._frozen = True return model_config - def get_bindings_model_config(self, - tokens_per_block: Optional[int] = None - ) -> "ModelConfigCpp": + def get_bindings_model_config( + self, + tokens_per_block: Optional[int] = None, + kv_cache_config: Optional[KvCacheConfig] = None, + spec_config: Optional['SpeculativeConfig'] = None, + ) -> "ModelConfigCpp": """ This method is used to construct the bindings config for the model. Currently it adheres to gptJsonConfig.cpp::createModelConfig, which assumes @@ -667,7 +675,8 @@ def ceil_div(a, b): hidden_size = ceil_div(self.pretrained_config.hidden_size, attn_tp_size) num_layers = self.pretrained_config.num_hidden_layers - num_attention_layers = self.get_num_attention_layers() + num_attention_layers = self.get_num_attention_layers( + kv_cache_config, spec_config) if (self.spec_config is not None and self.spec_config.spec_dec_mode.is_mtp_one_model()): num_layers += self.spec_config.num_nextn_predict_layers @@ -693,6 +702,7 @@ def ceil_div(a, b): num_key_value_heads = getattr(self.pretrained_config, "num_key_value_heads", num_heads) + if isinstance(num_key_value_heads, (list, tuple)): # Per-layer KV heads (e.g., Nemotron-NAS, variable GQA models) num_kv_heads_per_layer = [ @@ -796,10 +806,32 @@ def get_layer_types(self) -> Optional[List[LayerTypeCpp]]: else: return None - def get_num_attention_layers(self): - if is_nemotron_hybrid(self.pretrained_config): + def get_num_attention_layers( + self, + kv_cache_config: Optional[KvCacheConfig] = None, + spec_config: Optional['SpeculativeConfig'] = None): + """Return the number of layers that need KV cache blocks. + + For hybrid models using the MixedMambaHybridCacheManager path + (TRTLLM_USE_CPP_MAMBA=1 for disagg), only attention layers need KV + cache blocks, so we return the attention-only count. 
+ + For the default CppMambaHybridCacheManager path (including speculative + decoding), both attention and mamba layers are managed in the unified + KV cache pool, so we return num_hidden_layers (all layers). + """ + use_disagg = os.environ.get('TRTLLM_USE_CPP_MAMBA', '0') == '1' + use_reuse = kv_cache_config is not None and kv_cache_config.enable_block_reuse + + use_v1_mamba_manager = use_disagg + if is_hybrid_linear( + self.pretrained_config) and use_v1_mamba_manager and use_reuse: + logger.warning( + "Block reuse does not work with MTP or disagg for hybrid linear models" + ) + if is_nemotron_hybrid(self.pretrained_config) and use_v1_mamba_manager: return self.pretrained_config.hybrid_override_pattern.count("*") - elif is_qwen3_hybrid(self.pretrained_config): + elif is_qwen3_hybrid(self.pretrained_config) and use_v1_mamba_manager: return get_qwen3_hybrid_num_attention_layers(self.pretrained_config) else: return self.pretrained_config.num_hidden_layers diff --git a/tensorrt_llm/_torch/modules/mamba/gdn_mixer.py b/tensorrt_llm/_torch/modules/mamba/gdn_mixer.py index 88e9aa0c2ca..d32421b33c1 100644 --- a/tensorrt_llm/_torch/modules/mamba/gdn_mixer.py +++ b/tensorrt_llm/_torch/modules/mamba/gdn_mixer.py @@ -763,9 +763,14 @@ def forward( state_indices_p, state_indices_d = torch.split(state_indices, batch_split_size) if num_prefills > 0: - ssm_states[state_indices_p] = torch.zeros( + # PyExecutor guarantees prefill requests are placed before decode requests + has_initial_states_p = has_initial_states[:num_prefills] + ssm_states[state_indices_p[~has_initial_states_p]] = torch.zeros( (), dtype=ssm_states.dtype, device=ssm_states.device ) + conv_states[state_indices_p[~has_initial_states_p]] = torch.zeros( + (), dtype=conv_states.dtype, device=conv_states.device + ) is_target_verify = ( num_decodes > 0 diff --git a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py index 3cbd88f4337..a5043c8621e 100644 --- 
a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py +++ b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py @@ -383,9 +383,8 @@ def forward( is_target_verify = attn_metadata.kv_cache_manager.is_speculative( ) and spec_metadata is not None if is_target_verify: - # Speculative decoding only supported with Python path assert layer_cache is not None, \ - "Speculative decoding requires Python MambaCacheManager" + "Speculative decoding requires mamba_layer_cache() support" # TODO: support dynamic speculation, will add current_draft_len later [TRTLLM-10319] draft_token_num = spec_metadata.max_draft_len + 1 intermediate_conv_states = layer_cache.intermediate_conv_window diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index b000427620d..32040b44736 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -29,14 +29,14 @@ from ..model_config import ModelConfig from ..speculative import (get_num_extra_kv_tokens, get_num_spec_layers, get_spec_decoder, should_use_separate_draft_kv_cache) -from .config_utils import (get_qwen3_hybrid_layer_masks, is_mla, - is_nemotron_hybrid, is_qwen3_hybrid) +from .config_utils import (get_qwen3_hybrid_layer_masks, is_hybrid_linear, + is_mla, is_nemotron_hybrid, is_qwen3_hybrid) from .dwdp import DwdpManager from .guided_decoder import GuidedDecoder from .kv_cache_connector import KvCacheConnectorManager from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver from .llm_request import ExecutorResponse -from .mamba_cache_manager import MambaHybridCacheManager +from .mamba_cache_manager import BaseMambaCacheManager, MambaHybridCacheManager from .model_engine import PyTorchModelEngine from .py_executor import PyExecutor from .resource_manager import (KVCacheManager, KVCacheManagerV2, @@ -62,7 +62,7 @@ def get_kv_cache_manager_cls(model_config: ModelConfig, sparse_attn_config = model_config.sparse_attention_config if sparse_attn_config 
is not None: return get_sparse_attn_kv_cache_manager(sparse_attn_config) - elif is_nemotron_hybrid(config) or is_qwen3_hybrid(config): + elif is_hybrid_linear(config): return MambaHybridCacheManager else: return KVCacheManagerV2 if kv_cache_config.use_kv_cache_manager_v2 else KVCacheManager @@ -961,7 +961,7 @@ def _create_kv_cache_manager( # - If layer_mask[i] is True, include layer i # - For layers beyond hybrid_override_pattern, treat them as attention layers pattern_len = len(config.hybrid_override_pattern) - hybrid_layer_mask = [] + full_attention_layer_mask = [] mamba_layer_mask = [] for i, include in enumerate(layer_mask): if i < pattern_len: @@ -972,13 +972,14 @@ def _create_kv_cache_manager( # Beyond the pattern (e.g., MTP/draft layers), treat as attention-only is_attention = True is_mamba = False - hybrid_layer_mask.append(is_attention and include) + full_attention_layer_mask.append(is_attention and include) mamba_layer_mask.append(is_mamba and include) - num_layers = sum(hybrid_layer_mask) + num_full_attention_layers = sum(full_attention_layer_mask) mamba_num_layers = sum(mamba_layer_mask) else: - num_layers = config.hybrid_override_pattern.count("*") - hybrid_layer_mask = [ + num_full_attention_layers = config.hybrid_override_pattern.count( + "*") + full_attention_layer_mask = [ char == "*" for char in config.hybrid_override_pattern ] mamba_num_layers = config.hybrid_override_pattern.count("M") @@ -994,9 +995,9 @@ def _create_kv_cache_manager( from ..speculative.utils import get_num_spec_layers num_spec_layers = get_num_spec_layers(spec_config) if num_spec_layers > 0: - hybrid_layer_mask.extend([True] * num_spec_layers) + full_attention_layer_mask.extend([True] * num_spec_layers) mamba_layer_mask.extend([False] * num_spec_layers) - num_layers += num_spec_layers + num_full_attention_layers += num_spec_layers kv_cache_manager = kv_cache_manager_cls( # mamba cache parameters config.ssm_state_size, @@ -1012,8 +1013,8 @@ def _create_kv_cache_manager( # kv 
cache parameters kv_cache_config, tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF, - num_layers=num_layers, - layer_mask=hybrid_layer_mask, + num_layers=num_full_attention_layers, + layer_mask=full_attention_layer_mask, num_kv_heads=per_layer_num_kv_heads, head_dim=head_dim, tokens_per_block=tokens_per_block, @@ -1034,9 +1035,9 @@ def _create_kv_cache_manager( raise NotImplementedError( "Connector manager is not supported for MambaHybridCacheManager." ) - hybrid_layer_mask, mamba_layer_mask = get_qwen3_hybrid_layer_masks( + full_attention_layer_mask, mamba_layer_mask = get_qwen3_hybrid_layer_masks( config) - # For hybrid models, hybrid_layer_mask is always passed as + # For hybrid models, full_attention_layer_mask is always passed as # layer_mask to KVCacheManager, which means get_pp_layers # sees a non-None layer_mask and won't auto-add spec layers. # Extend the masks here to include MTP spec layers (full @@ -1045,9 +1046,9 @@ def _create_kv_cache_manager( from ..speculative.utils import get_num_spec_layers num_spec_layers = get_num_spec_layers(spec_config) if num_spec_layers > 0: - hybrid_layer_mask.extend([True] * num_spec_layers) + full_attention_layer_mask.extend([True] * num_spec_layers) mamba_layer_mask.extend([False] * num_spec_layers) - num_layers = sum(hybrid_layer_mask) + num_full_attention_layers = sum(full_attention_layer_mask) num_mamba_layers = sum(mamba_layer_mask) kv_cache_manager = kv_cache_manager_cls( # mamba cache parameters @@ -1064,8 +1065,8 @@ def _create_kv_cache_manager( # kv cache parameters kv_cache_config, tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF, - num_layers=num_layers, - layer_mask=hybrid_layer_mask, + num_layers=num_full_attention_layers, + layer_mask=full_attention_layer_mask, num_kv_heads=per_layer_num_kv_heads, head_dim=head_dim, tokens_per_block=tokens_per_block, @@ -1081,7 +1082,9 @@ def _create_kv_cache_manager( # NOTE: this is a workaround for VSWA to switch to calculate_max_num_blocks_for_vswa 
in KVCahceManager is_vswa = is_vswa_enabled(kv_cache_config) binding_model_config = _model_config.get_bindings_model_config( - tokens_per_block=tokens_per_block) if is_vswa else None + tokens_per_block=tokens_per_block, + kv_cache_config=kv_cache_config, + spec_config=spec_config) if is_vswa else None kv_cache_manager = kv_cache_manager_cls( kv_cache_config, @@ -1363,7 +1366,7 @@ def create_py_executor_instance( # For hybrid models, this has both impl and mamba_impl mamba_cache_manager = None - if isinstance(kv_cache_manager, MambaHybridCacheManager): + if isinstance(kv_cache_manager, BaseMambaCacheManager): mamba_cache_manager = kv_cache_manager kv_cache_transceiver = create_kv_cache_transceiver( diff --git a/tensorrt_llm/_torch/pyexecutor/config_utils.py b/tensorrt_llm/_torch/pyexecutor/config_utils.py index 9c3b4c37560..11b7a6160c7 100644 --- a/tensorrt_llm/_torch/pyexecutor/config_utils.py +++ b/tensorrt_llm/_torch/pyexecutor/config_utils.py @@ -5,6 +5,10 @@ import transformers +def is_hybrid_linear(config): + return is_nemotron_hybrid(config) or is_qwen3_hybrid(config) + + def is_nemotron_hybrid(config): if hasattr(config, "hybrid_override_pattern" ) and config.hybrid_override_pattern is not None and len( diff --git a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py index 94a317b54bb..a0f232b797d 100644 --- a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py @@ -13,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import math import os +from abc import ABC, abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Union @@ -24,12 +26,16 @@ if TYPE_CHECKING: from tensorrt_llm._torch.attention_backend.interface import \ AttentionMetadata + from tensorrt_llm.llmapi.llm_args import DecodingBaseConfig from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest from tensorrt_llm._torch.pyexecutor.resource_manager import ( BaseResourceManager, CacheTypeCpp, DataType, KVCacheManager, get_pp_layers) from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests -from tensorrt_llm._utils import prefer_pinned, torch_dtype_to_binding +from tensorrt_llm._utils import (nvtx_range, prefer_pinned, + torch_dtype_to_binding) +from tensorrt_llm.bindings.internal.batch_manager import ( + LinearAttentionMetadata, LinearCacheType) from tensorrt_llm.llmapi.llm_args import KvCacheConfig from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping @@ -58,8 +64,56 @@ def use_cpp_mamba_cache_manager() -> bool: return os.environ.get('TRTLLM_USE_CPP_MAMBA', '0') == '1' +class BaseMambaCacheManager(ABC): + """Abstract interface for accessing mamba/recurrent state caches.""" + + @abstractmethod + def get_state_indices(self, *args, **kwargs) -> torch.Tensor: + """Return slot indices of each request. + + Shape: [max_batch_size] + """ + ... + + @abstractmethod + def get_conv_states(self, layer_idx: int) -> torch.Tensor: + """Return conv states for specific layer. + + Shape: [slot_size, conv_dim, d_conv - 1] + """ + ... + + @abstractmethod + def get_ssm_states(self, layer_idx: int) -> torch.Tensor: + """Return SSM states for specific layer. + + Shape: [slot_size, num_heads, head_dim, d_state] + """ + ... + + @abstractmethod + def get_mamba_ssm_cache_dtype(self) -> torch.dtype: + ... + + @abstractmethod + def is_speculative(self) -> bool: + ... 
+ + @abstractmethod + def mamba_layer_cache( + self, layer_idx: int + ) -> Union['PythonMambaCacheManager.State', + 'PythonMambaCacheManager.SpeculativeState', None]: + ... + + class CppMambaCacheManager(BaseResourceManager): - """C++ backed Mamba cache manager using RnnStateManager bindings.""" + """Mamba state manager backed by the C++ RnnStateManager bindings. + + Manages only mamba states (conv + SSM). Used when TRTLLM_USE_CPP_MAMBA=1, + which is required for disaggregated serving deployments. + Does not support speculative decoding. + """ def __init__( self, @@ -165,6 +219,11 @@ def shutdown(self): class PythonMambaCacheManager(BaseResourceManager): + """Pure-Python mamba state manager with speculative decoding support. + + Manages only mamba states (conv + SSM) using PyTorch tensors on GPU. + Supports caching intermediate states for speculative decoding verification. + """ @dataclass(frozen=True, kw_only=True) class State: @@ -491,7 +550,13 @@ def update_mamba_states(self, attn_metadata: "AttentionMetadata", conv_states[:, state_indices_d, :] = accepted_conv_state -class MambaCacheManager(BaseResourceManager): +class MambaCacheManager(BaseResourceManager, BaseMambaCacheManager): + """Facade for standalone mamba state management (no KV cache). + + Delegates to CppMambaCacheManager (when TRTLLM_USE_CPP_MAMBA=1, required + for disaggregated serving) or PythonMambaCacheManager (default, supports + speculative decoding). + """ def __init__( self, @@ -617,7 +682,13 @@ def update_mamba_states(self, attn_metadata: "AttentionMetadata", self._impl.update_mamba_states(attn_metadata, num_accepted_tokens) -class MambaHybridCacheManager(KVCacheManager, MambaCacheManager): +class MixedMambaHybridCacheManager(KVCacheManager, MambaCacheManager): + """Hybrid cache manager combining separate KVCacheManager and MambaCacheManager. + + Manages KV cache and mamba states in independent pools. Used for + speculative decoding or disaggregated serving (via CppMambaCacheManager). 
+ Does not support block reuse / prefix caching for mamba states. + """ def __init__( self, @@ -698,9 +769,9 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): MambaCacheManager.prepare_resources(self, scheduled_batch) KVCacheManager.prepare_resources(self, scheduled_batch) - def free_resources(self, request: LlmRequest): + def free_resources(self, request: LlmRequest, pin_on_release: bool = False): MambaCacheManager.free_resources(self, request) - KVCacheManager.free_resources(self, request) + KVCacheManager.free_resources(self, request, pin_on_release) def add_dummy_requests(self, request_ids: List[int], **kwargs): MambaCacheManager.add_dummy_requests(self, request_ids) @@ -721,3 +792,532 @@ def update_mamba_states(self, attn_metadata: "AttentionMetadata", num_accepted_tokens: torch.Tensor): MambaCacheManager.update_mamba_states(self, attn_metadata, num_accepted_tokens) + + +def calc_context_stop_positions(prompt_len: int, + tokens_per_block: int, + mamba_state_cache_interval: int, + save_last_snapshot: bool = False) -> list[int]: + """Compute token positions at which mamba state snapshots should be saved. + + Returns positions spaced by ``mamba_state_cache_interval`` plus the final + prompt length (and optionally the last block-aligned position). + """ + stop_positions = list( + range(mamba_state_cache_interval, prompt_len, + mamba_state_cache_interval)) + last_ckpt = prompt_len // tokens_per_block * tokens_per_block + if save_last_snapshot and (last_ckpt not in stop_positions): + stop_positions.append(last_ckpt) + if prompt_len not in stop_positions: + stop_positions.append(prompt_len) + return stop_positions + + +class CppMambaHybridCacheManager(KVCacheManager, BaseMambaCacheManager): + """Hybrid cache manager storing mamba states inside the KVCacheManager pool. + + Both KV cache blocks and recurrent state blocks are managed by the unified + C++ KVCacheManager, enabling block reuse / prefix caching across attention + and mamba layers. 
This is the default hybrid manager.
+
+    Speculative decoding is supported via separate intermediate state tensors
+    allocated outside the unified pool. Disaggregated serving is not supported.
+    """
+
+    def __init__(
+        self,
+        # mamba cache parameters
+        mamba_d_state: int,
+        mamba_d_conv: int,
+        mamba_num_heads: int,
+        mamba_n_groups: int,
+        mamba_head_dim: int,
+        mamba_num_layers: int,
+        mamba_layer_mask: List[bool],
+        mamba_cache_dtype: torch.dtype,
+        mamba_ssm_cache_dtype: torch.dtype,
+        kv_cache_config: KvCacheConfig,
+        kv_cache_type: CacheTypeCpp,
+        *,
+        num_layers: int,
+        num_kv_heads: Union[int, List[Optional[int]]],
+        head_dim: int,
+        tokens_per_block: int,
+        # Note that max_seq_len is not necessarily equal to kv_cache_config.num_tokens.
+        # It's derived from the model's BuildConfig for consistency with the C++ backend.
+        max_seq_len: int,
+        max_batch_size: int,
+        mapping: Mapping,
+        dtype: DataType = DataType.HALF,
+        spec_config: Optional["DecodingBaseConfig"] = None,
+        layer_mask: Optional[List[bool]] = None,
+        is_estimating_kv_cache: bool = False,
+        **kwargs,
+    ) -> None:
+
+        logger.debug(f"mamba_num_layers: {mamba_num_layers}")
+        logger.debug(f"mamba_layer_mask: {mamba_layer_mask}")
+        logger.debug(f"num_layers: {num_layers}")
+        logger.debug(f"layer_mask: {layer_mask}")
+
+        if mamba_num_layers > 0:
+            self.mamba_pp_layers, _ = get_pp_layers(
+                mamba_num_layers,
+                mapping,
+                layer_mask=mamba_layer_mask,
+            )
+        else:
+            # No mamba layers on this rank — skip the get_pp_layers fallback
+            # that would insert a fake layer 0.
+ self.mamba_pp_layers = [] + + # Derive ssm_state_shape and conv_state_shape from mamba params (same as MambaCacheManager) + tp_size = mapping.tp_size if not mapping.enable_attention_dp else 1 + d_inner = mamba_head_dim * mamba_num_heads + conv_dim = d_inner + 2 * mamba_n_groups * mamba_d_state + nheads = mamba_num_heads + assert nheads % tp_size == 0, "mamba_num_heads must be divisible by tp_size" + assert conv_dim % tp_size == 0, "conv_dim must be divisible by tp_size" + conv_dim = conv_dim // tp_size + nheads = nheads // tp_size + self.conv_state_shape = [conv_dim, mamba_d_conv - 1] + self.ssm_state_shape = [nheads, mamba_head_dim, mamba_d_state] + self.ssm_state_dtype = mamba_ssm_cache_dtype + self.conv_state_dtype = mamba_cache_dtype + self.ssm_count = math.prod(self.ssm_state_shape) + self.conv_count = math.prod(self.conv_state_shape) + self.ssm_bytes = self.ssm_count * self.ssm_state_dtype.itemsize + self.conv_bytes = self.conv_count * self.conv_state_dtype.itemsize + # round conv_bytes to 1KB + self.conv_bytes = ((self.conv_bytes + 1023) // 1024) * 1024 + + total_bytes = self.ssm_bytes + self.conv_bytes + if total_bytes % self.ssm_state_dtype.itemsize != 0: + raise RuntimeError( + f"Total state bytes ({total_bytes}) not divisible by " + f"ssm_state_dtype size ({self.ssm_state_dtype.itemsize})") + if total_bytes % self.conv_state_dtype.itemsize != 0: + raise RuntimeError( + f"Total state bytes ({total_bytes}) not divisible by " + f"conv_state_dtype size ({self.conv_state_dtype.itemsize})") + if self.ssm_bytes % self.conv_state_dtype.itemsize != 0: + raise RuntimeError( + f"SSM state bytes ({self.ssm_bytes}) not divisible by " + f"conv_state_dtype size ({self.conv_state_dtype.itemsize})") + if self.mamba_pp_layers: + self.linear_attention_metadata = LinearAttentionMetadata() + self.linear_attention_metadata.cache_type = LinearCacheType.RECURRENT_STATES.value + self.linear_attention_metadata.all_recurrent_states_bytes = self.ssm_bytes + self.conv_bytes + 
self.linear_attention_metadata.states_snapshot_interval = kv_cache_config.mamba_state_cache_interval + else: + self.linear_attention_metadata = None + kv_cache_config = kv_cache_config.model_copy(deep=True) + if kv_cache_config.enable_partial_reuse: + logger.warning( + "Partial reuse is not supported for mamba hybrid models, disabling partial reuse" + ) + kv_cache_config.enable_partial_reuse = False + + full_attention_layer_mask = layer_mask.copy() + + kv_cache_config.max_attention_window = [] + layer_mask = [ + mamba_layer_mask[i] or full_attention_layer_mask[i] + for i in range(len(mamba_layer_mask)) + ] + for i in range(len(layer_mask)): + if layer_mask[i]: + kv_cache_config.max_attention_window.append( + LinearCacheType.RECURRENT_STATES. + value if mamba_layer_mask[i] else max_seq_len) + # pass remaining arguments to super class + super().__init__( + kv_cache_config, + kv_cache_type, + num_layers=mamba_num_layers + num_layers, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + tokens_per_block=tokens_per_block, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + mapping=mapping, + dtype=dtype, + spec_config=spec_config, + layer_mask=layer_mask, + is_estimating_kv_cache=is_estimating_kv_cache, + linear_attention_metadata=self.linear_attention_metadata, + ) + self.mamba_layer_offsets = {} + for idx, layer_id in enumerate(self.mamba_pp_layers): + self.mamba_layer_offsets[layer_id] = idx + self.num_mamba_layers = mamba_num_layers + self.host_block_offsets = torch.zeros([ + self.impl.num_pools, self.max_batch_size, 2, self.max_blocks_per_seq + ], + dtype=torch.int32, + device="cpu") + self.requests = [] + if self.mamba_pp_layers: + self.recurrent_states_pool_index = self.kv_cache_pool_mapping[ + self.layer_offsets[self.mamba_pp_layers[0]]][0] + self._setup_states_views() + self.cuda_state_indices = torch.zeros([self.max_batch_size], + dtype=torch.int32, + device="cuda") + self.kv_cache_config = kv_cache_config + + self.is_estimating_kv_cache = 
is_estimating_kv_cache + + # Speculative decoding support: allocate intermediate state tensors + # outside the unified pool for caching per-draft-token snapshots. + self._spec_config = spec_config + if spec_config is not None: + speculative_num_draft_tokens = spec_config.max_draft_len + num_local_mamba_layers = len(self.mamba_pp_layers) + ssm_state_shape_tuple = tuple(self.ssm_state_shape) + conv_state_shape_tuple = tuple(self.conv_state_shape) + + self._intermediate_ssm_states = torch.zeros( + size=(num_local_mamba_layers, max_batch_size, + speculative_num_draft_tokens + 1) + ssm_state_shape_tuple, + dtype=self.ssm_state_dtype, + device="cuda", + ) + + self._intermediate_conv_states = torch.zeros( + size=(num_local_mamba_layers, max_batch_size, + speculative_num_draft_tokens + 1) + + conv_state_shape_tuple, + dtype=self.conv_state_dtype, + device="cuda", + ) + + self._intermediate_state_indices = torch.arange(max_batch_size, + dtype=torch.int32, + device="cuda") + + logger.info( + f"CppMambaHybridCacheManager speculative buffers allocated. " + f"intermediate_ssm size: {get_tensor_size_bytes(self._intermediate_ssm_states) / GB:.2f}GB, " + f"intermediate_conv size: {get_tensor_size_bytes(self._intermediate_conv_states) / GB:.2f}GB" + ) + else: + self._intermediate_ssm_states = None + self._intermediate_conv_states = None + self._intermediate_state_indices = None + + def shutdown(self): + # Release tensor views into the pool before the pool memory is freed, + # so their deleters don't see stale pointers. + self.all_ssm_states = None + self.all_conv_states = None + self._intermediate_ssm_states = None + self._intermediate_conv_states = None + self._intermediate_state_indices = None + super().shutdown() + + def add_dummy_requests( + self, + request_ids: List[int], + # Note that token_nums should be past_kv_len + input_len (without + # spec decoding). The draft tokens will be added in this function, + # so we don't need to take care of it in the caller. 
When preparing + # token_nums, we should not take the draft tokens into account, so + # don't use the kv_cache_manager.max_seq_len, which includes both + # extra tokens and draft tokens. + token_nums: Optional[List[int]] = None, + is_gen: bool = False, + prepare_resource: bool = True, + max_num_draft_tokens: int = 0, + use_mrope: bool = False, + max_beam_width: int = 1, + # For capturable drafting loops. During normal inference, the draft model always + # has enough KV cache space to fit all of our draft tokens. During warmup, however, + # we need to make the KV cache manager aware that multiple autoregressive steps will + # occur. + num_extra_decoding_steps: int = 0, + draft_kv_cache_manager: Optional[KVCacheManager] = None, + ) -> List[LlmRequest]: + requests = super().add_dummy_requests(request_ids, token_nums, is_gen, + prepare_resource, + max_num_draft_tokens, use_mrope, + max_beam_width, + num_extra_decoding_steps, + draft_kv_cache_manager) + self.requests.extend(requests) + self._setup_state_indices() + return requests + + def update_resources(self, + scheduled_batch: ScheduledRequests, + attn_metadata: "AttentionMetadata" = None, + kv_cache_dtype_byte_size: float = None): + super().update_resources(scheduled_batch, attn_metadata, + kv_cache_dtype_byte_size) + + @nvtx_range("hybrid_prepare_resources") + def _prepare_resources(self, scheduled_batch: ScheduledRequests): + self.requests = scheduled_batch.context_requests + \ + scheduled_batch.generation_requests + for req in self.requests: + self.impl.copy_linear_attention_block(req) + print(f"req {req.py_request_id}:") + print( + f" Cache indices: {self.get_cache_indices(req, LinearCacheType.RECURRENT_STATES.value)}" + ) + self.impl.refresh_blocks() + self._setup_state_indices() + + def prepare_resources(self, scheduled_batch: ScheduledRequests): + print("--------") + super().prepare_resources(scheduled_batch) + self._prepare_resources(scheduled_batch) + + def is_speculative(self) -> bool: + return 
self._spec_config is not None + + def update_mamba_states(self, attn_metadata: "AttentionMetadata", + num_accepted_tokens: torch.Tensor): + # Note: cannot use @torch.compile here because all_ssm_states and + # all_conv_states are dtype-reinterpreted views of the C++ pool + # (uint8 -> typed), and aot_autograd does not support mutations on + # views with different dtypes. + batch_size = attn_metadata.num_seqs + num_contexts = attn_metadata.num_contexts + num_gens = batch_size - num_contexts + num_accepted_draft_tokens = num_accepted_tokens[ + num_contexts:num_contexts + num_gens] - 1 + state_indices_d = self.get_state_indices()[num_contexts:num_contexts + + num_gens] + + src_state_indices = self._intermediate_state_indices[:num_gens] + + # Copy accepted SSM states from intermediate buffer back to pool + accepted_ssm = self._intermediate_ssm_states[:, src_state_indices, + num_accepted_draft_tokens] + self.all_ssm_states[:, state_indices_d, :] = accepted_ssm + + # Copy accepted conv states from intermediate buffer back to pool + accepted_conv = self._intermediate_conv_states[:, src_state_indices, + num_accepted_draft_tokens] + self.all_conv_states[:, state_indices_d, :] = accepted_conv + + def get_ssm_states(self, layer_idx: int) -> torch.Tensor: + return self.all_ssm_states[self.mamba_layer_offsets[layer_idx]] + + def get_conv_states(self, layer_idx: int) -> torch.Tensor: + return self.all_conv_states[self.mamba_layer_offsets[layer_idx]] + + def get_intermediate_ssm_states(self, + layer_idx: int) -> Optional[torch.Tensor]: + if self._intermediate_ssm_states is None: + return None + layer_offset = self.mamba_layer_offsets[layer_idx] + return self._intermediate_ssm_states[layer_offset] + + def get_intermediate_conv_states(self, + layer_idx: int) -> Optional[torch.Tensor]: + if self._intermediate_conv_states is None: + return None + layer_offset = self.mamba_layer_offsets[layer_idx] + return self._intermediate_conv_states[layer_offset] + + def mamba_layer_cache( + self, 
layer_idx: int + ) -> Union[PythonMambaCacheManager.State, + PythonMambaCacheManager.SpeculativeState, None]: + conv = self.get_conv_states(layer_idx) + ssm = self.get_ssm_states(layer_idx) + if self._spec_config is not None: + layer_offset = self.mamba_layer_offsets[layer_idx] + return PythonMambaCacheManager.SpeculativeState( + conv=conv, + temporal=ssm, + intermediate_ssm=self._intermediate_ssm_states[layer_offset], + intermediate_conv_window=self. + _intermediate_conv_states[layer_offset], + ) + return PythonMambaCacheManager.State(conv=conv, temporal=ssm) + + def free_resources(self, request: LlmRequest, pin_on_release: bool = False): + if request in self.requests: + self.requests.remove(request) + super().free_resources(request, pin_on_release) + + def _setup_state_indices(self) -> None: + if not self.mamba_pp_layers: + return + + block_indices = [] + for req in self.requests: + if req.is_context_finished: + next_step = self.get_num_tokens(req) - 1 + elif self.kv_cache_config.enable_block_reuse: + next_step = (req.context_current_position - 1 + + req.context_chunk_size) + else: + next_step = req.prompt_len - 1 + block_indices.append(next_step // self.tokens_per_block) + self.impl.copy_batch_block_offsets( + self.host_block_offsets, + [req.py_request_id for req in self.requests], 1, 0) + host_block_offsets = torch.zeros([len(self.requests)], + dtype=torch.int32, + device="cpu") + for i in range(len(self.requests)): + # With layer-first pool layout, setOffsets produces the block index directly + # (no longer multiplied by num_mamba_layers) + value = self.host_block_offsets[self.recurrent_states_pool_index, i, + 0, block_indices[i]] + max_blocks = self.blocks_per_window[ + LinearCacheType.RECURRENT_STATES.value][0] + if value < 0 or value >= max_blocks: + raise RuntimeError( + f"Invalid recurrent state block index {value} " + f"(expected 0 <= index < {max_blocks}) for request {i}") + host_block_offsets[i] = value + + torch.fill_(self.cuda_state_indices, 0) + 
self.cuda_state_indices[:len(self.requests)] = host_block_offsets.cuda() + self._host_state_indices = host_block_offsets.clone() + + def get_state_indices( + self, + request_ids: Optional[List[int]] = None, + is_padding: Optional[List[bool]] = None) -> torch.Tensor: + return self.cuda_state_indices + + def calc_next_context_chunk_size(self, request: LlmRequest) -> int: + """Compute the next prefill chunk size for a context request when block reuse is enabled. + + When kv_cache_config.enable_block_reuse is True, context prefill must stop exactly at + the positions returned by calc_context_stop_positions (mamba_state_cache_interval boundaries + and block boundaries). This returns the chunk_size to use for the next prefill step so + that the next stop position is not exceeded. + + Args: + request: Context request with prompt_len and context_current_position set. + + Returns: + Number of tokens to prefill in the next step (0 if context is already complete). + """ + prompt_len = request.prompt_len + current = request.context_current_position + if current >= prompt_len: + return 0 + if not self.kv_cache_config.enable_block_reuse: + assert current == 0, f"Expected context_current_position to be 0 when block reuse is disabled, but got {current}" + return prompt_len - current + step = self.linear_attention_metadata.states_snapshot_interval + stop_positions = calc_context_stop_positions(prompt_len, + self.tokens_per_block, + step) + stop_positions = sorted(set(stop_positions)) + for pos in stop_positions: + if pos > current: + return pos - current + return prompt_len - current + + def _setup_states_views(self) -> None: + # Pool layout: {numLayers, numBlocks, ssm_bytes + conv_bytes} (as uint8) + pool: torch.Tensor = self.impl.get_recurrent_states_pool().view( + torch.uint8).reshape(self.num_mamba_layers, -1, + self.ssm_bytes + self.conv_bytes) + num_blocks_in_pool = pool.shape[1] + self.all_ssm_states = pool[:, :, :self.ssm_bytes].view( + self.ssm_state_dtype).view( + 
[self.num_mamba_layers, num_blocks_in_pool] + + self.ssm_state_shape) + self.all_conv_states = pool[:, :, self.ssm_bytes:self.ssm_bytes + + self.conv_bytes].view( + self.conv_state_dtype).view([ + self.num_mamba_layers, + num_blocks_in_pool + ] + self.conv_state_shape) + + def get_mamba_ssm_cache_dtype(self) -> torch.dtype: + return self.ssm_state_dtype + + +class _MambaHybridCacheManagerMeta(type): + """Metaclass that enables isinstance/issubclass checks against + MambaHybridCacheManager for both Mixed and Cpp implementations.""" + + def __instancecheck__(cls, instance): + if cls is MambaHybridCacheManager: + return isinstance( + instance, + (MixedMambaHybridCacheManager, CppMambaHybridCacheManager)) + return super().__instancecheck__(instance) + + def __subclasscheck__(cls, subclass): + if cls is MambaHybridCacheManager: + return issubclass( + subclass, + (MixedMambaHybridCacheManager, CppMambaHybridCacheManager)) + return super().__subclasscheck__(subclass) + + def __getattr__(cls, name): + """Forward class-level attribute access (e.g. static methods) to + KVCacheManager. Add attributes here as needed.""" + return getattr(KVCacheManager, name) + + +class MambaHybridCacheManager(metaclass=_MambaHybridCacheManagerMeta): + """Factory that selects the appropriate hybrid cache manager. 
+ + Selection logic: + - TRTLLM_USE_CPP_MAMBA=1 (disaggregated serving) -> MixedMambaHybridCacheManager + - Otherwise (default, including speculative decoding) -> CppMambaHybridCacheManager + """ + + def __new__( + cls, + # mamba cache parameters + mamba_d_state: int, + mamba_d_conv: int, + mamba_num_heads: int, + mamba_n_groups: int, + mamba_head_dim: int, + mamba_num_layers: int, + mamba_layer_mask: List[bool], + mamba_cache_dtype: torch.dtype, + mamba_ssm_cache_dtype: torch.dtype, + # kv cache parameters + kv_cache_config: KvCacheConfig, + kv_cache_type: CacheTypeCpp, + **kwargs, + ): + positional_args = ( + mamba_d_state, + mamba_d_conv, + mamba_num_heads, + mamba_n_groups, + mamba_head_dim, + mamba_num_layers, + mamba_layer_mask, + mamba_cache_dtype, + mamba_ssm_cache_dtype, + kv_cache_config, + kv_cache_type, + ) + + if mamba_num_layers == 0: + logger.info( + "mamba_num_layers is 0, using KVCacheManager without mamba caching" + ) + # kwargs.pop("") + return KVCacheManager(kv_cache_config, kv_cache_type, **kwargs) + + use_mixed = use_cpp_mamba_cache_manager() + + if use_mixed: + logger.info( + "Using MixedMambaHybridCacheManager for hybrid cache management" + ) + return MixedMambaHybridCacheManager(*positional_args, **kwargs) + else: + logger.info( + "Using CppMambaHybridCacheManager for hybrid cache management") + return CppMambaHybridCacheManager(*positional_args, **kwargs) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 783aa17b9c5..5f52aefff0f 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1936,6 +1936,7 @@ def _executor_loop(self): iter_start_time = time.time() iter_stats = None while True: + print("loop") self.hang_detector.checkpoint() profile_step() if self.enable_iter_perf_stats: @@ -2834,7 +2835,10 @@ def _waiting_requests(self, context_requests: list[LlmRequest], def _schedule(self): scheduler_output = 
self.scheduler.schedule_request( self.active_requests, self.inflight_req_ids) - + print( + f"self.active_requests {[req.py_request_id for req in self.active_requests]}" + ) + print(f"scheduler_output: {scheduler_output}") scheduled_context_requests = scheduler_output.context_requests if self.enable_attention_dp and self.attention_dp_enable_balance: scheduled_context_requests = self._balance_adp_requests( diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index e0aa739d869..a08318149bf 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -36,7 +36,7 @@ from ._util import (KvCacheCreator, _adjust_torch_mem_fraction, create_py_executor_instance, instantiate_sampler, is_mla, validate_feature_combination) -from .config_utils import is_nemotron_hybrid, is_qwen3_hybrid +from .config_utils import is_hybrid_linear from .dwdp import DwdpManager from .guided_decoder import CapturableGuidedDecoder, GuidedDecoder from .kv_cache_connector import KvCacheConnectorManager @@ -538,10 +538,12 @@ def drafting_loop_wrapper(model): cache_transceiver_config.max_tokens_in_buffer = net_max_seq_len config = model_engine.model.model_config.pretrained_config - if (is_nemotron_hybrid(config) - or is_qwen3_hybrid(config)) and kv_cache_config.enable_block_reuse: + if is_hybrid_linear(config) and kv_cache_config.enable_block_reuse and ( + cache_transceiver_config is not None + and cache_transceiver_config.backend is not None): logger.warning( - "Disabling block reuse for MambaHybridCacheManager-based models") + "Disabling block reuse for MambaHybridCacheManager-based models when disagg is enabled" + ) kv_cache_config.enable_block_reuse = False _set_model_engines_cache_reuse([model_engine, draft_model_engine], False) @@ -598,6 +600,10 @@ def drafting_loop_wrapper(model): else: ctx_chunk_config = None + if kv_cache_config.enable_block_reuse and 
is_hybrid_linear(config): + ctx_chunk_config = (ContextChunkingPolicy.FORCE_CHUNK, + kv_cache_config.mamba_state_cache_interval) + guided_decoder: Optional[GuidedDecoder] = None if guided_decoding_config is not None: with allocation_scope(ExecutorMemoryType.GUIDED_DECODER): @@ -719,7 +725,7 @@ def drafting_loop_wrapper(model): is_disagg = (cache_transceiver_config is not None and cache_transceiver_config.backend is not None) - is_hybrid = is_nemotron_hybrid(config) or is_qwen3_hybrid(config) + is_hybrid = is_hybrid_linear(config) if is_disagg and is_hybrid: if cache_transceiver_config.transceiver_runtime != "PYTHON" or os.environ.get( diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 971ff8c5402..7c9eba15331 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -16,11 +16,13 @@ from tensorrt_llm._utils import (TensorWrapper, convert_to_torch_tensor, get_size_in_bytes, mpi_comm, mpi_disabled, prefer_pinned, torch_comm) -from tensorrt_llm.bindings.internal.batch_manager import KvCacheStats +from tensorrt_llm.bindings.internal.batch_manager import ( + KvCacheStats, LinearAttentionMetadata, LinearCacheType) from tensorrt_llm.bindings.internal.batch_manager.kv_cache_manager_v2_utils import ( IndexMapper, copy_batch_block_offsets_to_device) from tensorrt_llm.bindings.internal.runtime import TaskLayerModuleConfig -from tensorrt_llm.llmapi.llm_args import KvCacheConfig, PeftCacheConfig +from tensorrt_llm.llmapi.llm_args import (KvCacheConfig, PeftCacheConfig, + PybindMirror) from tensorrt_llm.lora_helper import LoraConfig from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig from tensorrt_llm.runtime import ModelConfig as ModelConfigPython @@ -280,6 +282,7 @@ def __init__( indexer_k_cache_index_head_dim: int = 0, is_estimating_kv_cache: bool = False, execution_stream: Optional[torch.cuda.Stream] = None, + 
linear_attention_metadata: Optional[LinearAttentionMetadata] = None, **kwargs, ) -> None: self.mapping = mapping @@ -354,12 +357,16 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], self.max_draft_len = spec_config.max_draft_len if spec_config is not None else 0 self.max_total_draft_tokens = (spec_config.tokens_per_gen_step - 1) if spec_config is not None else 0 + self.linear_attention_metadata = linear_attention_metadata # Determine max_attention_window_vec if kv_cache_config.max_attention_window is None: # Use max_seq_len as default max_attention_window self.max_attention_window_vec = [max_seq_len] else: + print( + f"Original max_attention_window from config: {kv_cache_config.max_attention_window}" + ) self.max_attention_window_vec = kv_cache_config.max_attention_window.copy( ) # Make a copy to avoid modifying original # Clamp all window sizes to max_seq_len before calculating the @@ -373,8 +380,12 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], if kv_cache_config.sink_token_length is not None else 0) - # Determine if this is VSWA (Variable Sliding Window Attention) - self.is_vswa = len(set(self.max_attention_window_vec)) > 1 + # Determine if this is VSWA (Variable Sliding Window Attention). + # The `w > 0` check excludes LinearCacheType.RECURRENT_STATES sentinel + # values (negative) used by hybrid linear attention models. + self.is_vswa = len(set(self.max_attention_window_vec)) > 1 and all( + w > 0 for w in self.max_attention_window_vec) + self.is_linear_attention = linear_attention_metadata is not None # Calculate kv cache blocks for each window size # FIXME: flashinfer.py accesses kv_cache_manager.blocks_in_primary_pool @@ -385,6 +396,7 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], # max_tokens under _util.py::try_prepare_estimation # Since this is a dry run, assigning the same max_tokens capacity # to all window sizes as they are full attentions is enough. 
+ self.blocks_in_primary_pool = int(kv_cache_config.max_tokens // tokens_per_block) @@ -399,16 +411,56 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], (self.blocks_in_primary_pool, self.blocks_in_secondary_pool) for window_size in set(self.max_attention_window_vec) } + if self.is_linear_attention: + if len(self.max_attention_window_vec) != self.num_layers: + print( + f"Original max_attention_window_vec: {self.max_attention_window_vec}" + ) + # self.max_attention_window_vec is a pattern, repeat it to match num_layers + self.max_attention_window_vec = ( + self.max_attention_window_vec * + (self.num_layers // len(self.max_attention_window_vec) + + 1))[:self.num_layers] + print( + f"Adjusted max_attention_window_vec for linear attention: {self.max_attention_window_vec}" + ) + # _util.py::try_prepare_estimation can't estimate linear attentions properly + num_linear_layers = sum( + 1 if self.max_attention_window_vec[layer] == + LinearCacheType.RECURRENT_STATES.value else 0 + for layer in self.pp_layers) + bytes_per_linear_block = linear_attention_metadata.all_recurrent_states_bytes * num_linear_layers + num_attention_layers = self.num_local_layers - num_linear_layers + # get_cache_bytes_per_token() calculates assuming all layers are full attention layers + total_bytes_per_token = self.get_cache_bytes_per_token( + ) * num_attention_layers // self.num_local_layers + total_bytes_per_token += bytes_per_linear_block * self.max_batch_size // kv_cache_config.max_tokens + max_snapshots = self.max_batch_size + if kv_cache_config.enable_block_reuse: + total_bytes_per_token += bytes_per_linear_block // linear_attention_metadata.states_snapshot_interval + + expand_factor = total_bytes_per_token / self.get_cache_bytes_per_token( + ) + + kv_cache_config.max_tokens = int(kv_cache_config.max_tokens // + expand_factor) + self.blocks_in_primary_pool = int(kv_cache_config.max_tokens // + tokens_per_block) + blocks_per_window[self.max_seq_len] = ( + 
self.blocks_in_primary_pool, self.blocks_in_secondary_pool) + if kv_cache_config.enable_block_reuse: + max_snapshots = max( + kv_cache_config.max_tokens // + linear_attention_metadata.states_snapshot_interval, + self.max_batch_size) + + blocks_per_window[LinearCacheType.RECURRENT_STATES.value] = ( + int(max_snapshots), 0) logger.info( f"[kv cache manager] Primary/secondary blocks for window sizes set to {blocks_per_window} for estimation dry run" ) else: - if self.is_vswa: - # VSWA case: use C++ implementation for variable window sizes - if model_config is None: - raise ValueError( - "model_config is required for VSWA (Variable Sliding Window Attention)" - ) + if self.is_vswa or self.is_linear_attention: assert isinstance( kv_cache_config, KvCacheConfig ), "calculate_max_num_blocks_for_vswa only accepts KvCacheConfig" @@ -495,6 +547,7 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], self._stream = execution_stream if execution_stream is not None else torch.cuda.Stream( ) logger.info(f"[KVCacheManager] execution_stream: {self._stream}") + logger.info(f"[KVCacheManager] blocks_per_window: {blocks_per_window}") kwargs = { 'num_kv_heads_per_layer': self.num_kv_heads_per_layer, 'size_per_head': head_dim, @@ -507,7 +560,7 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], 'dtype': dtype, 'sink_token_length': sink_token_length, 'stream': self._stream.cuda_stream, # Pass to BufferManager - 'max_sequence_length': max_seq_len, + 'max_sequence_length': self.max_seq_len, 'enable_block_reuse': kv_cache_config.enable_block_reuse, 'onboard_blocks': kv_cache_config.onboard_blocks, 'cache_type': kv_cache_type, @@ -517,7 +570,8 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], 'enable_indexer_k_cache': enable_indexer_k_cache, 'indexer_k_cache_quant_block_size': indexer_k_cache_quant_block_size, - 'indexer_k_cache_index_head_dim': indexer_k_cache_index_head_dim + 'indexer_k_cache_index_head_dim': 
indexer_k_cache_index_head_dim, + 'linear_attention_metadata': linear_attention_metadata } if self.event_buffer_max_size > 0: @@ -561,6 +615,7 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], dtype=torch.int32, pin_memory=prefer_pinned(), device='cpu') + self.blocks_per_window = blocks_per_window def probe_prefix_match_length(self, input_tokens, lora_task_id=None): """Probe the KV cache radix tree for prefix match length. @@ -599,6 +654,10 @@ def shutdown(self): def get_max_resource_count(self) -> int: return self.impl.max_num_blocks + def get_num_tokens(self, request: LlmRequest) -> int: + # LlmRequest.get_num_tokens is out of sync with GenerationRequest when overlap scheduler is enabled. + return self.impl.get_token_count(request.py_request_id) + def get_needed_resource_to_completion(self, request: LlmRequest) -> int: # TODO: the C++ implementation of this method can be used, but the # Python and C++ schedulers currently do not agree on what "needed @@ -1058,9 +1117,9 @@ def get_batch_cache_indices( return result def get_num_free_blocks(self) -> int: - if self.is_vswa: + if self.is_vswa or self.is_linear_attention: logger.info( - f"For VSWA case, we return the minimum of the number of free blocks for each window size: {self.impl.get_kv_cache_stats().num_free_blocks_per_window_size}" + f"For {'linear attention' if self.is_linear_attention else 'VSWA'} case, we return the minimum of the number of free blocks for each window size: {self.impl.get_kv_cache_stats().num_free_blocks_per_window_size}" ) return min(self.impl.get_kv_cache_stats(). num_free_blocks_per_window_size.values()) @@ -1351,7 +1410,7 @@ def calculate_cache_size_per_token(layers: Set[int]) -> int: def calculate_max_num_blocks_for_vswa( self, kv_cache_config: KvCacheConfig, - model_config: ModelConfigCpp, + model_config: Optional[ModelConfigCpp], extra_cost_memory: int = 0) -> dict[int, tuple[int, int]]: """ Currently, this function is added to support *ONLY* VSWA. 
@@ -1377,7 +1436,6 @@ def calculate_max_num_blocks_for_vswa( # VSWA on Torch backend has not supported the cross attention. is_cross_attention = False # check model config - assert model_config.layer_types is not None, "layer_types have to be set correctly for VSWA" # Construct WorldConfig from self.mapping world_config_cpp = WorldConfig( @@ -1401,19 +1459,45 @@ def calculate_max_num_blocks_for_vswa( f"secondary_pool_memory_bytes is set to {self._secondary_pool_memory_bytes/1024**3}GB" ) - # Adjust the window sizes to fit the memory if even a single sequence - # cannot fit in the memory. - window_size_to_layers, max_attention_window_vec = self.adjust_window_sizes_for_vswa( - window_size_to_layers=window_size_to_layers, - max_attention_window_vec=self.max_attention_window_vec, - model_config=model_config, - kv_cache_config=kv_cache_config, - pool_memory_bytes=self._primary_pool_memory_bytes, - kv_factor=self.kv_factor, - dtype=self.dtype, - is_cross_attention=is_cross_attention, - ) - self.max_attention_window_vec = max_attention_window_vec + if self.is_linear_attention: + blocks_per_window = KVCacheManagerCpp.calculate_max_num_blocks( + config=PybindMirror.maybe_to_pybind(kv_cache_config), + dtype=self.dtype, + num_kv_heads_per_layer=list(self.num_kv_heads_per_layer), + size_per_head=self.head_dim, + tokens_per_block=self.tokens_per_block, + world_config=world_config_cpp, + window_size_to_layers=window_size_to_layers, + allotted_primary_mem_bytes=self._primary_pool_memory_bytes, + allotted_secondary_mem_bytes=self._secondary_pool_memory_bytes, + extra_cost_memory=extra_cost_memory, + kv_factor=self.kv_factor, + max_batch_size=self.max_batch_size, + linear_attention_metadata=PybindMirror.maybe_to_pybind( + self.linear_attention_metadata), + ) + return blocks_per_window + + # VSWA case: use C++ implementation for variable window sizes + if model_config is None: + raise ValueError( + "model_config is required for VSWA (Variable Sliding Window Attention)" + ) + assert 
model_config.layer_types is not None, "layer_types have to be set correctly for VSWA" + if self.is_vswa: + # Adjust the window sizes to fit the memory if even a single sequence + # cannot fit in the memory. + window_size_to_layers, max_attention_window_vec = self.adjust_window_sizes_for_vswa( + window_size_to_layers=window_size_to_layers, + max_attention_window_vec=self.max_attention_window_vec, + model_config=model_config, + kv_cache_config=kv_cache_config, + pool_memory_bytes=self._primary_pool_memory_bytes, + kv_factor=self.kv_factor, + dtype=self.dtype, + is_cross_attention=is_cross_attention, + ) + self.max_attention_window_vec = max_attention_window_vec def calculate_cache_size_per_token(layers: Set[int]) -> int: # Same as BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler/scheduler.py index 4f2d56c657b..c6684c4ed2b 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler/scheduler.py @@ -284,6 +284,8 @@ def schedule_request( self.capacity_scheduler.schedule_request(active_requests) ) + print(f"fitting_requests: {[req.request_id for req in fitting_requests]}") + context_requests, generation_requests = self.micro_batch_scheduler.schedule( fitting_requests, inflight_request_ids ) @@ -434,7 +436,7 @@ def schedule( ): break - logger.debug( + logger.info( f"context request scheduled: ID {req.request_id}" + (f" (reusable {reusable})" if reusable > 0 else "") ) @@ -459,7 +461,7 @@ def schedule( compute_tokens = max_context_length all_context_requests_fit = False - logger.debug( + logger.info( f"contexts-to-be-chunked request scheduled: ID {req.request_id}" + (f" (reusable {reusable})" if reusable > 0 else "") ) @@ -482,7 +484,7 @@ def schedule( if scheduled_beam_width == 0: scheduled_beam_width = beam_width elif scheduled_beam_width != beam_width: - logger.debug( + logger.info( 
f"generation request skipped: ID {req.request_id} since its " f"beam width ({beam_width}) is different from scheduled ones " f"({scheduled_beam_width})" @@ -524,7 +526,7 @@ def schedule( reusable = req.estimated_reusable_tokens if req.is_first_context_chunk else 0 compute_tokens = max(0, req.context_chunk_size - reusable) batch_num_tokens += compute_tokens - logger.debug( + logger.info( f"context request scheduled: ID {req.request_id}, " f"chunk size {req.context_chunk_size}" + (f", reusable {reusable}" if reusable > 0 else "") @@ -535,11 +537,11 @@ def schedule( self._sort_requests(context_requests, generation_requests, not all_context_requests_fit) # Summary logs - logger.debug( + logger.info( f"batchSize (num ctx/enc requests + num gen requests): " f"{len(context_requests) + len(generation_requests)}" ) - logger.debug( + logger.info( f"batchNumTokens (num ctx/enc input tokens + num gen input tokens) " f"/ maxNumTokens: {batch_num_tokens} / {max_num_tokens or 0}" ) @@ -740,6 +742,9 @@ def _chunk_forced(self, requests: RequestList, capacity: Optional[int], unit_siz for req in requests: req.context_chunk_size = min(req.context_remaining_length, unit_size) if capacity is not None and total_tokens + req.context_chunk_size > capacity: + print( + f"Request ID {req.request_id} chunk size reduced to 0 to fit capacity {capacity}" + ) req.context_chunk_size = 0 total_tokens += req.context_chunk_size assert capacity is None or total_tokens <= capacity @@ -1491,6 +1496,12 @@ def schedule_request( self.capacity_scheduler.schedule_request(active_requests) ) + print( + f"After capacity scheduling: {len(fitting_requests)} fitting requests, " + f"{len(fitting_disagg_gen_init)} fitting disagg gen init requests, " + f"{len(paused_requests)} paused requests" + ) + # Step 2: MicroBatch Check (Who fits in token budget? 
+ Chunking) context_requests, generation_requests = self.micro_batch_scheduler.schedule( fitting_requests, inflight_request_ids diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 9f001b4e5ae..5f2585d43f0 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -27,6 +27,7 @@ except ImportError: PlacementGroup = None +from tensorrt_llm.bindings.internal.batch_manager import LinearCacheType from tensorrt_llm.lora_helper import (LoraConfig, get_default_trtllm_modules_to_hf_modules) @@ -2186,7 +2187,7 @@ class KvCacheConfig(StrictBaseModel, PybindMirror): description= "The maximum number of tokens that should be stored in the KV cache. If both `max_tokens` and `free_gpu_memory_fraction` are specified, memory corresponding to the minimum will be used." ) - max_attention_window: Optional[List[PositiveInt]] = Field( + max_attention_window: Optional[List[int]] = Field( default=None, min_length=1, description= @@ -2281,6 +2282,12 @@ class KvCacheConfig(StrictBaseModel, PybindMirror): tokens_per_block: int = Field(default=32, description="The number of tokens per block.") + # This is a pure python field, not a pybind field. It is only for the Pytorch backend. 
+ mamba_state_cache_interval: PositiveInt = Field( + default=256, + description= + "The number of tokens between cache steps in the Mamba prefix cache.") + use_kv_cache_manager_v2: bool = Field( default=False, status="prototype", @@ -2359,9 +2366,9 @@ def validate_max_attention_window(cls, v: Optional[List[int]]): raise ValueError( "kv_cache_config.max_attention_window must contain only integers" ) - if i <= 0: + if i <= 0 and i not in [LinearCacheType.RECURRENT_STATES.value]: raise ValueError( - "kv_cache_config.max_attention_window values must be positive" + "kv_cache_config.max_attention_window values must be positive or LinearCacheType.RECURRENT_STATES.value" ) return v diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 2c88b1e11f8..60283054b10 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -5892,27 +5892,35 @@ def test_bf16_4gpu(self, tp_size, pp_size, ep_size, cuda_graph, task.evaluate(llm) @skip_pre_blackwell - @pytest.mark.skip_less_device(4) @pytest.mark.parametrize("moe_backend", ["CUTLASS", "TRTLLM"], ids=["cutlass", "trtllm"]) @pytest.mark.parametrize( - "tp_size,pp_size,ep_size,cuda_graph,overlap_scheduler,attention_dp", [ - (1, 1, 1, True, True, False), - (4, 1, 1, True, True, False), - (4, 1, 4, True, True, True), - (4, 1, 4, True, True, False), - (4, 1, 4, False, False, False), + "tp_size,pp_size,ep_size,cuda_graph,overlap_scheduler,attention_dp,enable_block_reuse", + [ + (1, 1, 1, True, True, False, True), + (1, 1, 1, True, True, False, False), + (4, 1, 1, True, True, False, False), + (4, 1, 4, True, True, True, False), + (4, 1, 4, True, True, False, False), + (4, 1, 4, False, False, False, False), ], ids=[ - "tp1", "tp4ep1", "tp4ep4_adp_on", "tp4ep4_adp_off", - "no_cuda_graph_overlap" + "tp1_block_reuse", "tp1", "tp4ep1", "tp4ep4_adp_on", + "tp4ep4_adp_off", "no_cuda_graph_overlap" 
]) def test_nvfp4(self, moe_backend, tp_size, pp_size, ep_size, cuda_graph, - overlap_scheduler, attention_dp, mocker): + overlap_scheduler, attention_dp, enable_block_reuse, mocker): + gpu_needed = max(tp_size, ep_size) * pp_size + if get_device_count() < gpu_needed: + pytest.skip( + f"Device count {get_device_count()} is less than required {gpu_needed}" + ) model_path = f"{self.MODEL_PATH}/qwen3-next-80b-instruct-nvfp4-ptq-fp8kv" kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6, - enable_block_reuse=False) + enable_block_reuse=enable_block_reuse) + if enable_block_reuse: + kv_cache_config.mamba_state_cache_interval = 256 pytorch_config = dict(disable_overlap_scheduler=not overlap_scheduler, cuda_graph_config=CudaGraphConfig( max_batch_size=512, enable_padding=False) @@ -6026,7 +6034,8 @@ def test_bf16(self): task.evaluate(llm, extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS) - def test_fp8(self): + @parametrize_with_ids("enable_block_reuse", [False, True]) + def test_fp8(self, enable_block_reuse): model_dir = f"{self.MODEL_PATH}-FP8" # Model is being added to CI. Skip at the moment. 
if not os.path.exists(model_dir): @@ -6034,7 +6043,7 @@ def test_fp8(self): world_size = 1 kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8, - enable_block_reuse=False) + enable_block_reuse=enable_block_reuse) moe_config = MoeConfig(backend='DEEPGEMM') with LLM(model_dir, @@ -6682,6 +6691,49 @@ def test_nvfp4_8gpus(self, attention_dp, moe_backend): task.evaluate(llm, extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS) + @skip_pre_blackwell + @pytest.mark.skip_less_mpi_world_size(4) + @pytest.mark.parametrize( + "tp_size, ep_size, mamba_state_cache_interval, attention_dp", + [ + (4, 1, 256, False), + (4, 4, 512, False), + (4, 4, 256, True), + ], + ids=["TP4", "TEP4", "TEP4_ADP"], + ) + def test_nvfp4_4gpus_block_reuse(self, tp_size, ep_size, + mamba_state_cache_interval, attention_dp): + mtp_config = MTPDecodingConfig( + num_nextn_predict_layers=3, + mtp_eagle_one_model=True, + ) + with LLM( + f"{llm_models_root()}/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4", + kv_cache_config=KvCacheConfig( + enable_block_reuse=True, + mamba_ssm_cache_dtype="float16", + mamba_state_cache_interval=mamba_state_cache_interval, + free_gpu_memory_fraction=0.8, + ), + max_batch_size=32, + tensor_parallel_size=tp_size, + moe_expert_parallel_size=ep_size, + pipeline_parallel_size=1, + enable_attention_dp=attention_dp, + cuda_graph_config=CudaGraphConfig(max_batch_size=32, + enable_padding=True), + disable_overlap_scheduler=False, + moe_config=MoeConfig(backend="TRTLLM"), + speculative_config=mtp_config, + ) as llm: + task = MMLU(self.MODEL_NAME) + task.evaluate(llm, + extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm, + extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS) + @skip_pre_blackwell @pytest.mark.skip_less_mpi_world_size(8) @pytest.mark.parametrize( @@ -6733,7 +6785,7 @@ def test_nvfp4_8gpus_mtp(self): with LLM( model_path, kv_cache_config=KvCacheConfig( - enable_block_reuse=False, + enable_block_reuse=True, 
mamba_ssm_cache_dtype="float16", free_gpu_memory_fraction=0.5, ), diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index 229a7f1e180..5c3a40fd82d 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -192,7 +192,8 @@ accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_tr accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_attention_dp] accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4_4gpus[latency_moe_trtllm_eagle3] accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16 -accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8 +accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[enable_block_reuse=False] +accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[enable_block_reuse=True] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-cutlass-auto] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-cutlass-fp8] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-triton-auto] diff --git a/tests/integration/test_lists/qa/llm_function_core_sanity.txt b/tests/integration/test_lists/qa/llm_function_core_sanity.txt index 1378ce4efa4..a05426de9b4 100644 --- a/tests/integration/test_lists/qa/llm_function_core_sanity.txt +++ b/tests/integration/test_lists/qa/llm_function_core_sanity.txt @@ -176,7 +176,8 @@ accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency] accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[fp8-latency] accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[mxfp8-latency] accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16 -accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8 
+accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[enable_block_reuse=False] +accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[enable_block_reuse=True] # disaggregated serving accuracy test accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=False] diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index 8d1a900aa0c..4c1b2819153 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -73,8 +73,9 @@ l0_b200: - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.5] - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9] - accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16 - - accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8 - - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp1-cutlass] + - accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[enable_block_reuse=False] + - accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[enable_block_reuse=True] + - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp1_block_reuse-cutlass] - disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0] # nvbugs 5300551 - test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B] - test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-FP8-llama-3.1-model/Llama-3.1-8B-Instruct-FP8] diff --git a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml index 3cb1cefbfaa..ea41176a9eb 100644 --- a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml +++ b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml @@ -58,6 
+58,8 @@ l0_gb200_multi_gpus: - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2-moe_backend=CUTLASS] - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_4gpus_online_eplb[moe_backend=TRTLLM] - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_4gpus_online_eplb[moe_backend=CUTLASS] + - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_4gpus_block_reuse[TEP4] + - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_4gpus_block_reuse[TEP4_ADP] - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4[torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4[torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4[torch_compile=False] diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 0cfeddceaab..84ed7a838b3 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -326,7 +326,8 @@ accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B_Instruct_Eagle3::test_eagle accuracy/test_llm_api_autodeploy.py::TestNemotronNanoV3::test_accuracy[fp8-4-trtllm] SKIP (https://nvbugs/5997046) accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_python_scheduler[ep4-mtp_nextn=0] SKIP (https://nvbugs/5997051) perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_v32_fp4_blackwell-v32_fp4_tep8_mtp3_8k1k] SKIP (https://nvbugs/5997092) -accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8 SKIP (https://nvbugs/6004530) +accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[enable_block_reuse=False] SKIP (https://nvbugs/6004530) +accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[enable_block_reuse=True] SKIP (https://nvbugs/6004530) 
unittest/_torch/modules/moe/test_moe_module.py::test_configurable_moe_multi_gpu[parallel=DEP-comm=DEEPEP-e60_k4_h2048_i1408-seq=8-dtype=torch.bfloat16-backend=TRTLLM-quant=NVFP4-routing=Renormalize] SKIP (https://nvbugs/6007285) disaggregated/test_disaggregated.py::test_disaggregated_gpt_oss_120b_harmony[gpt_oss/gpt-oss-120b] SKIP (https://nvbugs/6011317) accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_bf16_4gpu_mtp_ar SKIP (https://nvbugs/5959992) diff --git a/tests/unittest/_torch/executor/test_kv_cache_manager.py b/tests/unittest/_torch/executor/test_kv_cache_manager.py new file mode 100644 index 00000000000..93721254ba8 --- /dev/null +++ b/tests/unittest/_torch/executor/test_kv_cache_manager.py @@ -0,0 +1,423 @@ +# ruff: noqa: E501 +from functools import reduce + +import torch + +import tensorrt_llm +from tensorrt_llm._torch.model_config import ModelConfig +from tensorrt_llm._torch.models.modeling_qwen3_next import Qwen3NextConfig +from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest +from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager +from tensorrt_llm.bindings.internal.batch_manager import ( + CacheType, + LinearAttentionMetadata, + LinearCacheType, +) +from tensorrt_llm.llmapi import KvCacheConfig, MoeConfig, MTPDecodingConfig, SchedulerConfig +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.sampling_params import SamplingParams + +text_questions = [ + # "Question: Mark wants to tip his server 20% on a $200 check. If his friend agrees to kick in $10, how much should Mark add?", + # "Question: The dinner bill for 6 friends came to $150. Silas said he would pay for half of the bill and the remaining friends could split the rest of the bill and leave a 10% tip for the whole meal. How many dollars will one of the friends pay?", + # "Question: Nancy takes 3 antacids per day when she eats Indian food, 2 antacids per day when she eats Mexican food, and 1 antacid per day otherwise. 
If Nancy eats Indian three times a week and Mexican twice a week, how many antacids does she take per month?", + # "Question: Mr. and Mrs. Boyden take their 3 children to a leisure park. They buy tickets for the whole family. The cost of an adult ticket is $6 more than the cost of a child ticket. The total cost of the 5 tickets is $77. What is the cost of an adult ticket?", + "Question: Mark builds an apartment that is 16 by 10 feet. There are 6 rooms in total. All the rooms are the same size except the living room which is as big as 3 other rooms. How big is the living room? Answer:", + # "Question: Adrien's total salary was 30 percent higher than Lylah's. Four years later, his salary had increased, and he was earning 40% more than what he was making four years ago. If Adrien's and Lylah's salary increased simultaneously, and Adrien earned $40000 four years ago, calculate the total salary the two were receiving four years later?", +] + +text_gsm = """Question: Mark wants to tip his server 20% on a $200 check. If his friend agrees to kick in $10, how much should Mark add? +Answer: First find the total tip amount: 20% * $200 = $<<20*.01*200=40>>40 +Then subtract the friend's contribution: $40 - $10 = $<<40-10=30>>30 +#### 30 + +Question: The dinner bill for 6 friends came to $150. Silas said he would pay for half of the bill and the remaining friends could split the rest of the bill and leave a 10% tip for the whole meal. How many dollars will one of the friends pay? +Answer: Silas paid half = 150/2 = <<150/2=75>>75 +Remaining bill paid by 5 friends = 75 + 10% of 150 = 75 + 15 = 90 +Each person will pay 1/5 which is 90/5 = <<90/5=18>>18 +Each friend will pay $<<18=18>>18. +#### 18 + +Question: Nancy takes 3 antacids per day when she eats Indian food, 2 antacids per day when she eats Mexican food, and 1 antacid per day otherwise. If Nancy eats Indian three times a week and Mexican twice a week, how many antacids does she take per month? 
+Answer: First find the total number of antacids Nancy takes after eating Indian food per week: 3 antacids/day * 3 days/week = <<3*3=9>>9 antacids/week +Then find the total number of antacids Nancy takes after eating Mexican food per week: 2 antacids/day * 2 days/week = <<2*2=4>>4 antacids/week +Then find the number of days she doesn't eat Indian food or Mexican food: 7 days/week - 3 days/week - 2 days/week = 2 days/week +Then find the total number of antacids Nancy takes per week: 9 antacids/week + 4 antacids/week + 2 antacids/week = <<9+4+2=15>>15 antacids/week +Then multiply her weekly antacid consumption by the number of weeks per month to find her monthly consumption: 15 antacids/week * 4 week/month = <<15*4=60>>60 antacids/month +#### 60 + +Question: Mr. and Mrs. Boyden take their 3 children to a leisure park. They buy tickets for the whole family. The cost of an adult ticket is $6 more than the cost of a child ticket. The total cost of the 5 tickets is $77. What is the cost of an adult ticket? +Answer: Let X be the cost of an adult ticket. +So the cost of a child ticket is X-6. +The total cost of the 5 tickets is X*2 + 3*(X-6) = 77. +X*2 + 3*X - 3*6 = 77. +5*X - 18 = 77. +5*X = 77 + 18 = 95 +X = <<19=19>>19 +#### 19 + +Question: Mark builds an apartment that is 16 by 10 feet. There are 6 rooms in total. All the rooms are the same size except the living room which is as big as 3 other rooms. How big is the living room? +Answer: Total square footage is 16*10=<<16*10=160>>160 square feet + +There are 3+3=<<3+3=6>>6 rooms +6-1=<<6-1=5>>5 of them are the same size +If x is the size of the normal room then the square footage of all rooms is 5x+3x=8x +So each room is 160/8=<<160/8=20>>20 square feet +So the living room is 20*3=<<20*3=60>>60 square feet +#### 60 + +Question: Adrien's total salary was 30 percent higher than Lylah's. Four years later, his salary had increased, and he was earning 40% more than what he was making four years ago. 
If Adrien's and Lylah's salary increased simultaneously, and Adrien earned $40000 four years ago, calculate the total salary the two were receiving four years later? +Answer:""" +text_poem = """以下是《长恨歌》的开头一部分,请帮助补充完整,直到结尾: +汉皇重色思倾国,御宇多年求不得。 +杨家有女初长成,养在深闺人未识。 +天生丽质难自弃,一朝选在君王侧。 +回眸一笑百媚生,六宫粉黛无颜色。 +春寒赐浴华清池,温泉水滑洗凝脂。 +侍儿扶起娇无力,始是新承恩泽时。 +云鬓花颜金步摇,芙蓉帐暖度春宵。 +春宵苦短日高起,从此君王不早朝。 +承欢侍宴无闲暇,春从春游夜专夜。 +后宫佳丽三千人,三千宠爱在一身。 +金屋妆成娇侍夜,玉楼宴罢醉和春。 +姊妹弟兄皆列土,可怜光彩生门户。 +遂令天下父母心,不重生男重生女。 +骊宫高处入青云,仙乐风飘处处闻。 +缓歌慢舞凝丝竹,尽日君王看不足。 +渔阳鼙鼓动地来,惊破霓裳羽衣曲。 +九重城阙烟尘生,千乘万骑西南行。 +翠华摇摇行复止,西出都门百余里。 +六军不发无奈何,宛转蛾眉马前死。 +花钿委地无人收,翠翘金雀玉搔头。 +君王掩面救不得,回看血泪相和流。 +黄埃散漫风萧索,云栈萦纡登剑阁。""" +b = """ +峨嵋山下少人行,旌旗无光日色薄。 +蜀江水碧蜀山青,圣主朝朝暮暮情。 +行宫见月伤心色,夜雨闻铃肠断声。 +天旋地转回龙驭,到此踌躇不能去。 +马嵬坡下泥土中,不见玉颜空死处。 +君臣相顾尽沾衣,东望都门信马归。 +归来池苑皆依旧,太液芙蓉未央柳。 +芙蓉如面柳如眉,对此如何不泪垂? +春风桃李花开日,秋雨梧桐叶落时。 +西宫南内多秋草,落叶满阶红不扫。 +梨园弟子白发新,椒房阿监青娥老。 +夕殿萤飞思悄然,孤灯挑尽未成眠。 +迟迟钟鼓初长夜,耿耿星河欲曙天。 +鸳鸯瓦冷霜华重,翡翠衾寒谁与共? +悠悠生死别经年,魂魄不曾来入梦。 +临邛道士鸿都客,能以精诚致魂魄。 +为感君王辗转思,遂教方士殷勤觅。 +排空驭气奔如电,升天入地求之遍。 +上穷碧落下黄泉,两处茫茫皆不见。 +忽闻海上有仙山,山在虚无缥缈间。 +楼阁玲珑五云起,其中绰约多仙子。 +中有一人字太真,雪肤花貌参差是。 +金阙西厢叩玉扃,转教小玉报双成。 +闻道汉家天子使,九华帐里梦魂惊。 +揽衣推枕起徘徊,珠箔银屏迤逦开。 +云鬓半偏新睡觉,花冠不整下堂来。 +风吹仙袂飘飖举,犹似霓裳羽衣舞。 +玉容寂寞泪阑干,梨花一枝春带雨。 +含情凝睇谢君王,一别音容两渺茫。 +昭阳殿里恩爱绝,蓬莱宫中日月长。 +回头下望人寰处,不见长安见尘雾。 +惟将旧物表深情,钿合金钗寄将去。 +钗留一股合一扇,钗擘黄金合分钿。""" + + +def create_linear_attention_metadata(): + """Create a LinearAttentionMetadata for recurrent-states linear attention.""" + metadata = LinearAttentionMetadata() + metadata.linear_layer_indices = [0, 1] + metadata.cache_type = LinearCacheType.RECURRENT_STATES.value + metadata.all_recurrent_states_bytes = 440 * 1024 # 440 KB + metadata.input_features_bytes_per_token = 0 + metadata.states_snapshot_interval = 96 + return metadata + + +def create_kv_cache_manager(kv_cache_config=None): + """Create a KVCacheManager using the Python wrapper.""" + num_layers = 6 + num_kv_heads = 2 + head_dim = 128 + tokens_per_block = 32 + max_seq_len = 1024 + max_batch_size = 7 + mapping = Mapping() + + # Load the HuggingFace PretrainedConfig and convert to C++ bindings ModelConfig 
+ pretrained_config = Qwen3NextConfig.from_json_file( + "/home/scratch.trt_llm_data/llm-models/Qwen3-Next/Qwen3-Next-80B-A3B-Instruct/config.json" + ) + torch_model_config = ModelConfig( + pretrained_config=pretrained_config, + mapping=mapping, + ) + binding_model_config = torch_model_config.get_bindings_model_config( + tokens_per_block=tokens_per_block + ) + + return KVCacheManager( + kv_cache_config=KvCacheConfig( + free_gpu_memory_fraction=0.1, + max_tokens=8192, + enable_block_reuse=True, + max_attention_window=[max_seq_len, LinearCacheType.RECURRENT_STATES.value], + enable_partial_reuse=False, + ), + kv_cache_type=CacheType.SELF, + num_layers=num_layers, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + tokens_per_block=tokens_per_block, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + mapping=mapping, + linear_attention_metadata=create_linear_attention_metadata(), + model_config=binding_model_config, + ) + + +def create_llm_request(request_id, input_tokens, max_new_tokens=1): + """Helper to create an LlmRequest for testing.""" + sampling_params = SamplingParams() + return LlmRequest( + request_id=request_id, + max_new_tokens=max_new_tokens, + input_tokens=input_tokens, + sampling_config=tensorrt_llm.bindings.SamplingConfig( + sampling_params._get_sampling_config() + ), + is_streaming=False, + ) + + +def test_linear_attention_1batch(): + prompt_len = 256 + kv_cache_manager = create_kv_cache_manager() + try: + # Create an LlmRequest + req = create_llm_request( + request_id=0, + input_tokens=range(prompt_len), + ) + + # Add the sequence to the KV cache manager + kv_cache_manager.impl.add_sequence(req.py_request_id, prompt_len, 1, req) + + # Verify blocks were allocated + block_ids = kv_cache_manager.get_cache_indices(req, LinearCacheType.RECURRENT_STATES.value) + print(f"block_ids: {block_ids}") + block_ids = kv_cache_manager.get_cache_indices(req, kv_cache_manager.max_seq_len) + print(f"block_ids: {block_ids}") + # return + 
req.context_current_position = 0 + req.context_chunk_size = 96 + + block_idx = ( + (req.get_num_tokens(0) if req.is_context_finished else req.context_current_position) + - 1 + + req.context_chunk_size + ) // kv_cache_manager.tokens_per_block + host_linear_block_offsets = torch.zeros( + [ + kv_cache_manager.num_pools, + kv_cache_manager.max_batch_size, + 2, + kv_cache_manager.max_blocks_per_seq, + ], + dtype=torch.int32, + device="cpu", + ) + # input("Press Enter to continue...") + # kv_cache_manager.impl.copy_batch_block_offsets(host_kv_cache_block_offsets, [req.py_request_id], 1, 0) + # print(f"offsets: {host_kv_cache_block_offsets}") + kv_cache_manager.impl.copy_linear_batch_block_offsets( + host_linear_block_offsets, [req.py_request_id], 1, 0 + ) + print(f"offsets: {host_linear_block_offsets}") + + print(f"block_idx: {block_idx}") + batch0_current_block_offset = host_linear_block_offsets[0, 0, 0, block_idx] + print(f"batch0_current_block_offset: {batch0_current_block_offset}") + + pool0 = kv_cache_manager.impl.get_primary_pool_data(0) + pool1 = kv_cache_manager.impl.get_primary_pool_data(3) + print(f"pool0: {pool0.shape}, {pool0.stride()}") + print(f"pool1: {pool1.shape}, {pool1.stride()}") + + # pool_shape = [primary_block_num, kv_cache_manager.num_layers // 2, kv_cache_manager.linear_attention_metadata.all_recurrent_states_bytes] + # import ctypes + # buffer = (ctypes.c_uint8 * reduce(lambda x, y: x * y, pool_shape)).from_address(pool_base_addr) + # pool_as_tensor = torch.from_dlpack(buffer, device='cuda').view(pool_shape) + ssm_shape = [ + # 3, # num_layers + 2, # num_heads + 128, # head_dim + 128, # d_state (=head_dim for Qwen3-Next) + ] + torch_dtype = tensorrt_llm._utils.str_dtype_to_torch( + tensorrt_llm._utils.binding_to_str_dtype(kv_cache_manager.dtype) + ) + ssm_size = reduce(lambda x, y: x * y, ssm_shape) + # With layer-first pool layout, get_primary_pool_data returns per-layer data. 
+ # pool1 shape: {numBlocks, kvFactor(1), blockSize} + pool_ssm_states = ( + pool1[:, 0, :ssm_size].view(torch_dtype).reshape([pool1.shape[0], *ssm_shape]) + ) + assert pool_ssm_states._is_view() + # batch0_current_block_offset is the block index directly (no num_layers factor) + my_ssm_states = pool_ssm_states[batch0_current_block_offset] + print(f"ssm_states: {my_ssm_states.shape}, {my_ssm_states.stride()}") + + # Add a generation token + # kv_cache_manager.impl.add_token(req.py_request_id) + + # Verify stats + # stats = kv_cache_manager.get_kv_cache_stats() + # assert stats.max_num_blocks > 0 + + # Clean up + # kv_cache_manager.free_resources(req) + finally: + kv_cache_manager.shutdown() + + +def test_linear_attention_multi_batch(): + prompt_len = 256 + kv_cache_manager = create_kv_cache_manager() + try: + num_requests = 4 + requests = [ + create_llm_request( + request_id=i, + input_tokens=range(prompt_len), + ) + for i in range(num_requests) + ] + # Create an LlmRequest + # Add the sequence to the KV cache manager + for req in requests: + kv_cache_manager.impl.add_sequence(req.py_request_id, prompt_len, 1, req) + + req.context_current_position = 0 + req.context_chunk_size = 96 + + block_idx = ( + (req.get_num_tokens(0) if req.is_context_finished else req.context_current_position) + - 1 + + req.context_chunk_size + ) // kv_cache_manager.tokens_per_block + host_linear_block_offsets = torch.zeros( + [ + kv_cache_manager.num_pools, + kv_cache_manager.max_batch_size, + 2, + kv_cache_manager.max_blocks_per_seq, + ], + dtype=torch.int32, + device="cpu", + ) + # input("Press Enter to continue...") + # kv_cache_manager.impl.copy_batch_block_offsets(host_kv_cache_block_offsets, [req.py_request_id], 1, 0) + # print(f"offsets: {host_kv_cache_block_offsets}") + kv_cache_manager.impl.copy_linear_batch_block_offsets( + host_linear_block_offsets, [0, 1, 2, 3], 1, 0 + ) + print(f"offsets: {host_linear_block_offsets}") + + print(f"block_idx: {block_idx}") + current_block_offset 
= host_linear_block_offsets[0, 0:num_requests, 0, block_idx] + print(f"current_block_offset: {current_block_offset}") + + pool0 = kv_cache_manager.impl.get_primary_pool_data(0) + pool1 = kv_cache_manager.impl.get_primary_pool_data(1) + print(f"pool0: {pool0.shape}, {pool0.stride()}") + print(f"pool1: {pool1.shape}, {pool1.stride()}") + + # pool_shape = [primary_block_num, kv_cache_manager.num_layers // 2, kv_cache_manager.linear_attention_metadata.all_recurrent_states_bytes] + # import ctypes + # buffer = (ctypes.c_uint8 * reduce(lambda x, y: x * y, pool_shape)).from_address(pool_base_addr) + # pool_as_tensor = torch.from_dlpack(buffer, device='cuda').view(pool_shape) + ssm_shape = [ + # 3, # num_layers + 2, # num_heads + 128, # head_dim + 128, # d_state (=head_dim for Qwen3-Next) + ] + torch_dtype = tensorrt_llm._utils.str_dtype_to_torch( + tensorrt_llm._utils.binding_to_str_dtype(kv_cache_manager.dtype) + ) + ssm_size = reduce(lambda x, y: x * y, ssm_shape) + pool_ssm_states = ( + pool1[:, 0, :ssm_size].view(torch_dtype).reshape([pool1.shape[0], *ssm_shape]) + ) + assert pool_ssm_states._is_view() + my_ssm_states = pool_ssm_states[current_block_offset] + print(f"ssm_states: {my_ssm_states.shape}, {my_ssm_states.stride()}") + assert my_ssm_states._is_view() + + # Add a generation token + # kv_cache_manager.impl.add_token(req.py_request_id) + + # Verify stats + # stats = kv_cache_manager.get_kv_cache_stats() + # assert stats.max_num_blocks > 0 + + # Clean up + # kv_cache_manager.free_resources(req) + finally: + kv_cache_manager.shutdown() + + +def test_qwen3_next_with_reuse(): + max_batch_size = 1 + # model_path = f"/home/scratch.trt_llm_data/llm-models/Qwen3-Next/qwen3-next-80b-instruct-nvfp4-ptq-fp8kv" + model_path = "/home/scratch.trt_llm_data/llm-models/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4/" + + kv_cache_config = KvCacheConfig( + free_gpu_memory_fraction=0.4, + max_tokens=163840, + # mamba_ssm_cache_dtype="float16", + mamba_state_cache_interval=128, + 
enable_block_reuse=True, + ) + pytorch_config = dict( + disable_overlap_scheduler=True, + max_batch_size=max_batch_size, + enable_chunked_prefill=True, + cuda_graph_config=None, + # CudaGraphConfig(max_batch_size=256, enable_padding=True) + ) + moe_config = MoeConfig(backend="TRTLLM") + + # inputs = [input] * max_batch_size * 1 + inputs = [text_poem] + with tensorrt_llm.LLM( + model_path, + tensor_parallel_size=1, + scheduler_config=SchedulerConfig(use_python_scheduler=True), + max_num_tokens=4096, + pipeline_parallel_size=1, + moe_expert_parallel_size=1, + kv_cache_config=kv_cache_config, + **pytorch_config, + moe_config=moe_config, + speculative_config=MTPDecodingConfig( + num_nextn_predict_layers=3, + mtp_eagle_one_model=True, + ), + ) as llm: + result1 = llm.generate(inputs, SamplingParams(max_tokens=200)) + result2 = llm.generate(inputs, SamplingParams(max_tokens=200)) + for i in range(len(inputs)): + print(result1[i].outputs[0].text) + print(result2[i].outputs[0].text) + print("--------------------------------") + + +if __name__ == "__main__": + test_qwen3_next_with_reuse() diff --git a/tests/unittest/disaggregated/test_mamba_transfer.py b/tests/unittest/disaggregated/test_mamba_transfer.py index f637b0e6f14..5be3e43e7d4 100644 --- a/tests/unittest/disaggregated/test_mamba_transfer.py +++ b/tests/unittest/disaggregated/test_mamba_transfer.py @@ -25,7 +25,7 @@ from tensorrt_llm import DisaggregatedParams, Mapping, SamplingParams from tensorrt_llm._torch.disaggregation.transceiver import KvCacheTransceiverV2 from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest, LlmRequestType -from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import MambaHybridCacheManager +from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import MixedMambaHybridCacheManager from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests from tensorrt_llm.bindings import DataType from tensorrt_llm.bindings.internal.batch_manager import CacheType as CacheTypeCpp 
@@ -180,7 +180,7 @@ def _init(rank): def _create_managers(tp): - """Create MambaHybridCacheManagers for all TP ranks (PP=1). + """Create MixedMambaHybridCacheManagers for all TP ranks (PP=1). Layer 0 is a dummy attention layer required by page table infrastructure. Layers 1..NUM_MAMBA_LAYERS are mamba layers under test. @@ -188,7 +188,7 @@ def _create_managers(tp): managers = [] for rank in range(tp): mapping = Mapping(world_size=tp, rank=rank, tp_size=tp, pp_size=1) - mgr = MambaHybridCacheManager( + mgr = MixedMambaHybridCacheManager( mamba_d_state=MAMBA_D_STATE, mamba_d_conv=MAMBA_D_CONV, mamba_num_heads=MAMBA_NUM_HEADS,