diff --git a/.claude/README.md b/.claude/README.md new file mode 100644 index 00000000000..4b42a9e69e2 --- /dev/null +++ b/.claude/README.md @@ -0,0 +1,82 @@ +# Custom Claude Code Skills & Agents for TensorRT-LLM + +## Background: Skills & agents in Claude Code + +Claude Code supports two extensibility mechanisms — **skills** and **agents** — +that let teams encode domain expertise into reusable, version-controlled +components. + +**Skills** are markdown playbooks that Claude follows step-by-step when +triggered. They are invoked via `/slash-commands` (e.g. `/perf-analysis`) or +matched automatically from natural-language requests. Each skill lives in its +own directory under `.claude/skills/` and can bundle reference materials that +Claude reads during execution. See +[Custom slash commands](https://code.claude.com/docs/en/skills) +for details. + +**Agents** (sub-agents) are specialist workers that Claude spawns in a separate +context to handle focused tasks. Each agent has its own system prompt, tool +access, and domain knowledge. Claude delegates to them when it determines a task +fits a specialist's scope, while you can also invoke agents directly. Agent +definitions live under `.claude/agents/`. See +[Custom sub-agents](https://code.claude.com/docs/en/sub-agents) +for details. + +## How skills and agents are loaded + +For users who are working with Claude Code under TensorRT-LLM project directory, +skills and agents are automatically discovered by Claude Code at startup — no +manual registration needed. Files placed in `.claude/skills/` and +`.claude/agents/` are picked up by convention. + +To verify what's loaded, launch Claude Code under TensorRT-LLM project directory +and type `/skills` or `/agents` in the Claude Code prompt to see available +skills and sub-agents. + +## How to use skills and agents + +There are two ways to trigger skills and agents: + +1. **Automatic dispatch** — just describe what you need in plain language + (e.g. 
"profile this workload", "compile TensorRT-LLM"). Claude Code will + match your request to the appropriate skill or delegate to the right + sub-agent automatically. + +2. **Manual invoke** — type `/` (e.g. `/perf-analysis`, + `/serve-config-guide`) to explicitly run a skill. For sub-agents, type + `@"" (agent)` (e.g. `@"exec-compile-specialist (agent)"`) to + delegate a task directly. This is useful when you know exactly which workflow you want. + +In most cases, automatic dispatch is sufficient — you don't need to memorize +skill or agent names. Manual invoke is there for when you want precise control. + +References: +* [Extend Claude with skills](https://code.claude.com/docs/en/skills) +* [Work with subagents](https://code.claude.com/docs/en/sub-agents#work-with-subagents) + +## Naming convention + +Every skill and agent name uses the format `-`. +The prefix identifies the primary work area; the descriptive part should be +short and not repeat it. + +| Prefix | Domain | Definition | +|---|---|---| +| `ad-` | AutoDeploy | Model onboarding, pipeline debugging, and execution for the AutoDeploy backend | +| `ci-` | CI/CD | CI failure retrieval, test diagnostics, and pipeline workflows | +| `exec-` | Execution infra | Environment setup and job execution (compile, run, container) | +| `kernel-` | Kernel development | Kernel writing, generation, and kernel-specific transforms | +| `perf-` | Performance work | Profiling, analysis, and tuning above the kernel layer (kernel modifications belong under `kernel-`) | +| `serve-` | Serving | Serving configuration, deployment, and runtime workflows | +| `trtllm-` | TRT-LLM dev workflows | Workflows for reading, modifying, and contributing to the codebase (static subsystem knowledge belongs in repo docs) | + +Guidelines: + +* If a skill doesn't fit any prefix, propose a new one and agree on its + boundary before using it. +* Use the prefix of the skill's **primary** domain, even if it orchestrates + across multiple domains. 
+* Agents follow the same convention. +* Good: `exec-local-compile`, `kernel-cuda-writing`, `perf-host-analysis` +* Bad: `exec-trtllm-compile`, `kernel-cuda-kernel-writing`, + `perf-trtllm-host-analysis` diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 06a864511e0..8f37eefe118 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -259,6 +259,42 @@ struct KvCacheStats std::size_t allocatedBytes{}; }; +/// @brief Per-iteration KV cache statistics. All delta counters represent changes since the last call to +/// getIterationStats(). Gauges are instantaneous snapshots. +struct KvCacheIterationStats +{ + // --- Instantaneous gauges --- + // Primary (GPU) pool + SizeType32 primaryMaxNumBlocks{0}; + SizeType32 primaryFreeNumBlocks{0}; + SizeType32 primaryUsedNumBlocks{0}; + // Secondary (host) pool + SizeType32 secondaryMaxNumBlocks{0}; + SizeType32 secondaryFreeNumBlocks{0}; + SizeType32 secondaryUsedNumBlocks{0}; + + // --- Per-iteration deltas (reset on each read) --- + // Context phase: block allocation and reuse + SizeType32 iterAllocTotalBlocks{0}; + SizeType32 iterAllocNewBlocks{0}; + SizeType32 iterReusedBlocks{0}; // = iterFullReusedBlocks + iterPartialReusedBlocks + SizeType32 iterFullReusedBlocks{0}; // blocks fully matched in radix tree + SizeType32 iterPartialReusedBlocks{0}; // blocks partially matched in radix tree + SizeType32 iterMissedBlocks{0}; + float iterCacheHitRate{0.0f}; + // Generation phase: block allocation + SizeType32 iterGenAllocBlocks{0}; + + // Transfer traffic deltas — host ↔ GPU + SizeType32 iterOnboardBlocks{0}; + std::size_t iterOnboardBytes{0}; + SizeType32 iterOffloadBlocks{0}; + std::size_t iterOffloadBytes{0}; + // Intra-device (GPU → GPU) block copies (e.g. 
partial reuse when source block has refs) + SizeType32 iterIntraDeviceCopyBlocks{0}; + std::size_t iterIntraDeviceCopyBytes{0}; +}; + // Basic building block of a paged KV cache - a single // cache block. This class just holds metadata, no pointers // since it is reused across all layers. @@ -815,6 +851,12 @@ class WindowBlockManager return mMissedBlocks; } + // Get num free blocks in the secondary (host) memory pool + [[nodiscard]] SizeType32 getNumFreeSecondaryBlocks() const noexcept; + + //! \brief Get iteration stats (deltas since last call) for this window. Resets internal delta snapshots. + [[nodiscard]] KvCacheIterationStats getAndResetIterationStats(); + [[nodiscard]] bool hasFreeBlocks(SizeType32 numRequired = 1) const { return getNumFreeBlocks() >= numRequired; @@ -1128,16 +1170,22 @@ class WindowBlockManager std::shared_ptr<KVCacheTransferManager> mTransferManager; // Statistics for block allocations/reuse - // Total number of blocks allocated by all requests + // Total number of blocks allocated by all requests (context phase) SizeType32 mAllocTotalBlocks; - // Number of new blocks that were allocated + // Number of new blocks that were allocated (context phase) SizeType32 mAllocNewBlocks; - // Number of blocks that were reused + // Number of blocks that were fully reused (context phase) + SizeType32 mFullReusedBlocks; + // Number of blocks that were partially reused (context phase) + SizeType32 mPartialReusedBlocks; + // Number of blocks that were reused (full + partial, context phase) SizeType32 mReusedBlocks; // Number of unique blocks that were reused SizeType32 mReusedUniqueBlocks; - // Number of blocks that were not reused + // Number of blocks that were not reused (context phase) SizeType32 mMissedBlocks; + // Number of blocks allocated during generation phase + SizeType32 mGenAllocBlocks; // Only be 1 or 2. If 2: general KV stored. If 1: K == V for any token, so only K is stored to optimize the // max_num_tokens(For DeepSeek). 
Controlled by mCacheType SizeType32 mKVFactor; @@ -1154,6 +1202,15 @@ class WindowBlockManager // The kv cache connector manager std::shared_ptr<KvCacheConnectorManager> mKvCacheConnectorManager; + // Snapshot of cumulative counters at last iteration stats read (for delta computation) + SizeType32 mPrevAllocTotalBlocks{0}; + SizeType32 mPrevAllocNewBlocks{0}; + SizeType32 mPrevReusedBlocks{0}; + SizeType32 mPrevFullReusedBlocks{0}; + SizeType32 mPrevPartialReusedBlocks{0}; + SizeType32 mPrevMissedBlocks{0}; + SizeType32 mPrevGenAllocBlocks{0}; + // Mutex for the cached blocks root mutable std::mutex mCachedBlocksRootMutex; @@ -1359,6 +1416,19 @@ class BlockManager return sumWindows([](auto const& manager) { return manager.getNumMissedBlocks(); }); } + [[nodiscard]] SizeType32 getNumSecondaryBlocks() const + { + return sumWindows([](auto const& manager) { return manager.getNumSecondaryBlocks(); }); + } + + [[nodiscard]] SizeType32 getNumFreeSecondaryBlocks() const + { + return sumWindows([](auto const& manager) { return manager.getNumFreeSecondaryBlocks(); }); + } + + //! \brief Get per-window-size iteration stats. Resets delta snapshots for each window. + [[nodiscard]] std::map<SizeType32, KvCacheIterationStats> getAndResetIterationStats(); + [[nodiscard]] SizeType32 getNumLayers() const { return mNumLayers; @@ -1688,6 +1758,10 @@ class BaseKVCacheManager [[nodiscard]] virtual KvCacheStats getKvCacheStats() const = 0; + //! \brief Get per-iteration stats with delta counters, keyed by window size. + //! Resets delta snapshots on each call. 
+ [[nodiscard]] virtual std::map<SizeType32, KvCacheIterationStats> getIterationStats() = 0; + [[nodiscard]] virtual OffsetTableDimensions getOffsetTableDimensions() const = 0; [[nodiscard]] virtual std::deque<executor::KVCacheEvent> getLatestEvents( @@ -2046,6 +2120,11 @@ class KVCacheManager : public BaseKVCacheManager return kvCacheStats; } + [[nodiscard]] std::map<SizeType32, KvCacheIterationStats> getIterationStats() override + { + return mBlockManager.getAndResetIterationStats(); + } + [[nodiscard]] OffsetTableDimensions getOffsetTableDimensions() const override { OffsetTableDimensions dims; diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h index 00540dc671e..8c40a46045d 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h @@ -27,6 +27,20 @@ namespace kvc = tensorrt_llm::executor::kv_cache; namespace tensorrt_llm::batch_manager::kv_cache_manager { +/// @brief Statistics for block transfers. Returned by KVCacheTransferManager::getAndResetTransferStats(). +/// All counters are reset on read. +/// - onboard/offload: transfers between secondary (host) and primary (GPU) memory. +/// - intraDeviceCopy: GPU-to-GPU block copies (e.g. partial reuse when source block has refs). +struct KvCacheTransferStats +{ + SizeType32 onboardBlocks{0}; + std::size_t onboardBytes{0}; + SizeType32 offloadBlocks{0}; + std::size_t offloadBytes{0}; + SizeType32 intraDeviceCopyBlocks{0}; + std::size_t intraDeviceCopyBytes{0}; +}; + // The TransferManager accelerates transfers to/from the GPU by overlapping HtoD and DtoH transfers, and tracks ongoing // transfers in order to avoid race conditions. It is functionally equivalent to the prior approach of putting all // transfers into the forward pass stream. This is only ever used as a component of a KVCacheManager. @@ -57,6 +71,9 @@ class KVCacheTransferManager //! must be called after last call to KVCacheManager::addSequence in every step. 
void syncTransfers(); + //! \brief Get transfer stats accumulated since last call, and reset the counters. + [[nodiscard]] KvCacheTransferStats getAndResetTransferStats(); + private: //! \brief Get pointer to pool specified by cache block. static tr::ITensor::SharedPtr computeBlockPointer( @@ -79,6 +96,12 @@ class KVCacheTransferManager int numTokensToCopy = 0, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = ""); + //! \brief Compute total bytes actually transferred for a block copy across all pools. + //! \param pools The pool descriptors. + //! \param numTokensToCopy Number of tokens for partial copy (0 means full block). + [[nodiscard]] std::size_t computeBlockTransferBytes( + std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy) const; + runtime::BufferManager mBufferManager; runtime::BufferManager mOnboardManager; runtime::BufferManager mOffloadManager; @@ -90,6 +113,16 @@ class KVCacheTransferManager // Reference to parent loopback agent std::shared_ptr mLoopbackAgent; int mDeviceId; + + // Cumulative transfer statistics, reset on each call to getAndResetTransferStats(). + // Protected by mStatsMutex for thread-safe access. 
+ mutable std::mutex mStatsMutex; + SizeType32 mOnboardBlockCount{0}; + std::size_t mOnboardByteCount{0}; + SizeType32 mOffloadBlockCount{0}; + std::size_t mOffloadByteCount{0}; + SizeType32 mIntraDeviceCopyBlockCount{0}; + std::size_t mIntraDeviceCopyByteCount{0}; }; } // namespace tensorrt_llm::batch_manager::kv_cache_manager diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index a10b42fbd86..b88f09a0471 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -760,9 +760,12 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind , mTransferManager{std::make_shared<KVCacheTransferManager>(mBufferManager, mLoopbackAgent)} , mAllocTotalBlocks{0} , mAllocNewBlocks{0} + , mFullReusedBlocks{0} + , mPartialReusedBlocks{0} , mReusedBlocks{0} , mReusedUniqueBlocks{0} , mMissedBlocks{0} + , mGenAllocBlocks{0} , mKVFactor{(mCacheType == CacheType::kSELFKONLY || (linearAttentionMetadata.has_value() && linearAttentionMetadata->hasRecurrentStatesCache())) ? 
1 @@ -1518,6 +1521,14 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const& reusedBlockIds.insert(matchingBlockId); ++mReusedUniqueBlocks; } + if (partialMatch) + { + ++mPartialReusedBlocks; + } + else + { + ++mFullReusedBlocks; + } } ++blockItr; } @@ -1726,6 +1737,7 @@ void WindowBlockManager::adjustBlocksIfNeeded(GenerationRequest& sequence) { // Allocating a new block when the last token is a block boundary allocateBlock(sequence, /*shareAmongBeams=*/sequence.getBeamWidth() == 1); + ++mGenAllocBlocks; updateLastCacheBlockOffsets(sequence); } } @@ -2226,6 +2238,73 @@ void WindowBlockManager::releaseLastBlock(GenerationRequest& sequence) return numFree; } +[[nodiscard]] SizeType32 WindowBlockManager::getNumFreeSecondaryBlocks() const noexcept +{ + return mEvictionPolicy->getNumFreeBlocks(kSecondaryLevel); +} + +KvCacheIterationStats WindowBlockManager::getAndResetIterationStats() +{ + std::lock_guard<std::mutex> lock(mCachedBlocksRootMutex); + KvCacheIterationStats stats; + + // Instantaneous gauges + stats.primaryMaxNumBlocks = getNumPrimaryBlocks(); + stats.primaryFreeNumBlocks = getNumFreeBlocks(); + stats.primaryUsedNumBlocks = stats.primaryMaxNumBlocks - stats.primaryFreeNumBlocks; + stats.secondaryMaxNumBlocks = getNumSecondaryBlocks(); + stats.secondaryFreeNumBlocks = getNumFreeSecondaryBlocks(); + stats.secondaryUsedNumBlocks = stats.secondaryMaxNumBlocks - stats.secondaryFreeNumBlocks; + + // Compute deltas since last call — context phase + stats.iterAllocTotalBlocks = mAllocTotalBlocks - mPrevAllocTotalBlocks; + stats.iterAllocNewBlocks = mAllocNewBlocks - mPrevAllocNewBlocks; + stats.iterReusedBlocks = mReusedBlocks - mPrevReusedBlocks; + stats.iterFullReusedBlocks = mFullReusedBlocks - mPrevFullReusedBlocks; + stats.iterPartialReusedBlocks = mPartialReusedBlocks - mPrevPartialReusedBlocks; + stats.iterMissedBlocks = mMissedBlocks - mPrevMissedBlocks; + + auto const iterTotal = stats.iterReusedBlocks + stats.iterMissedBlocks; + 
stats.iterCacheHitRate + = iterTotal == 0 ? 0.0f : static_cast<float>(stats.iterReusedBlocks) / static_cast<float>(iterTotal); + + // Generation phase + stats.iterGenAllocBlocks = mGenAllocBlocks - mPrevGenAllocBlocks; + + // Snapshot current values for next delta + mPrevAllocTotalBlocks = mAllocTotalBlocks; + mPrevAllocNewBlocks = mAllocNewBlocks; + mPrevReusedBlocks = mReusedBlocks; + mPrevFullReusedBlocks = mFullReusedBlocks; + mPrevPartialReusedBlocks = mPartialReusedBlocks; + mPrevMissedBlocks = mMissedBlocks; + mPrevGenAllocBlocks = mGenAllocBlocks; + + // Transfer stats (collected from transfer manager) + if (mTransferManager) + { + auto transferStats = mTransferManager->getAndResetTransferStats(); + stats.iterOnboardBlocks = transferStats.onboardBlocks; + stats.iterOnboardBytes = transferStats.onboardBytes; + stats.iterOffloadBlocks = transferStats.offloadBlocks; + stats.iterOffloadBytes = transferStats.offloadBytes; + stats.iterIntraDeviceCopyBlocks = transferStats.intraDeviceCopyBlocks; + stats.iterIntraDeviceCopyBytes = transferStats.intraDeviceCopyBytes; + } + + return stats; +} + +std::map<SizeType32, KvCacheIterationStats> BlockManager::getAndResetIterationStats() +{ + std::map<SizeType32, KvCacheIterationStats> perWindowStats; + for (auto& [windowSize, manager] : mWindowBlockManagers) + { + perWindowStats[windowSize] = manager.getAndResetIterationStats(); + } + return perWindowStats; +} + std::deque<executor::KVCacheEvent> BlockManager::getLatestEvents(std::optional<std::chrono::milliseconds> timeout) const { return mEventManager ? 
mEventManager->getEvents(timeout) : std::deque<executor::KVCacheEvent>{}; diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp index 005ee9fa60d..9b5b71377b7 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp @@ -297,6 +297,22 @@ void KVCacheTransferManager::onboard(BlockPtr const& offloadedBlock, BlockPtr co copyBlock(offloadedBlock, block, pools, false, numTokensToCopy, mode, directory); + // Update transfer statistics — distinguish host→GPU onboard from GPU→GPU intra-device copy + { + std::lock_guard<std::mutex> lock(mStatsMutex); + auto bytes = computeBlockTransferBytes(pools, numTokensToCopy); + if (offloadedBlock->isPrimary()) + { + ++mIntraDeviceCopyBlockCount; + mIntraDeviceCopyByteCount += bytes; + } + else + { + ++mOnboardBlockCount; + mOnboardByteCount += bytes; + } + } + // Record new pending read from offloadedBlock mPendingReads[offloadedBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent(); mOnboardManager.getStream().record(mPendingReads[offloadedBlock->getMemoryPoolBlockIndex()]); @@ -333,6 +349,13 @@ void KVCacheTransferManager::offload(BlockPtr const& block, BlockPtr const& offl copyBlock(block, offloadBlock, pools, true, numTokensToCopy, mode, directory); + // Update transfer statistics + { + std::lock_guard<std::mutex> lock(mStatsMutex); + ++mOffloadBlockCount; + mOffloadByteCount += computeBlockTransferBytes(pools, numTokensToCopy); + } + // Record new pending read from block mPendingReads[block->getMemoryPoolBlockIndex()] = tr::CudaEvent(); mOffloadManager.getStream().record(mPendingReads[block->getMemoryPoolBlockIndex()]); @@ -371,4 +394,60 @@ void KVCacheTransferManager::syncTransfers() mPendingWrites.clear(); } +KvCacheTransferStats KVCacheTransferManager::getAndResetTransferStats() +{ + std::lock_guard<std::mutex> lock(mStatsMutex); + KvCacheTransferStats stats; + stats.onboardBlocks = mOnboardBlockCount; + stats.onboardBytes = 
mOnboardByteCount; + stats.offloadBlocks = mOffloadBlockCount; + stats.offloadBytes = mOffloadByteCount; + stats.intraDeviceCopyBlocks = mIntraDeviceCopyBlockCount; + stats.intraDeviceCopyBytes = mIntraDeviceCopyByteCount; + mOnboardBlockCount = 0; + mOnboardByteCount = 0; + mOffloadBlockCount = 0; + mOffloadByteCount = 0; + mIntraDeviceCopyBlockCount = 0; + mIntraDeviceCopyByteCount = 0; + return stats; +} + +std::size_t KVCacheTransferManager::computeBlockTransferBytes( + std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy) const +{ + std::size_t totalBytes = 0; + for (auto const& pool : pools) + { + if (!pool.primaryPtr) + { + continue; + } + + auto const dataType = pool.primaryPtr->getDataType(); + auto const bytesPerElement + = pool.primaryPtr->getSizeInBytes() / static_cast<std::size_t>(pool.primaryPtr->getSize()); + + // Mirror the logic in copyBlock: a partial copy only happens when numTokensToCopy > 0, + // the data type supports it (not kINT4/kFP4), not block scales, and numTokensToCopy < tokensPerBlock. 
+ bool const isPartialCopy = numTokensToCopy > 0 && dataType != nvinfer1::DataType::kINT4 + && dataType != nvinfer1::DataType::kFP4 && !pool.containsBlockScales + && numTokensToCopy < pool.tokensPerBlock; + + if (isPartialCopy) + { + // Partial copy transfers: numLayers * kvFactor * numKvHeads * sizePerHead * numTokensToCopy elements + totalBytes += static_cast<std::size_t>(pool.numLayers) * pool.kvFactor * pool.numKvHeads * pool.sizePerHead + * numTokensToCopy * bytesPerElement; + } + else + { + // Full block copy: numLayers * kvFactor * blockSize elements + // where blockSize = numKvHeads * sizePerHead * tokensPerBlock + totalBytes += static_cast<std::size_t>(pool.numLayers) * pool.kvFactor * pool.blockSize * bytesPerElement; + } + } + return totalBytes; +} + } // namespace tensorrt_llm::batch_manager::kv_cache_manager diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp index fe7687aee77..4ba5a3fbbaa 100755 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -106,6 +106,11 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager NB_OVERRIDE_PURE(getKvCacheStats); } + std::map<tbk::SizeType32, tbk::KvCacheIterationStats> getIterationStats() override + { + NB_OVERRIDE_PURE(getIterationStats); + } + void addToken(tb::LlmRequest::RequestIdType requestId) override { NB_OVERRIDE_PURE(addToken, requestId); @@ -346,6 +351,29 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) .def_rw("num_free_blocks_per_window_size", &tbk::KvCacheStats::numFreeBlocksPerWindowSize) .def_rw("allocated_bytes", &tbk::KvCacheStats::allocatedBytes); + nb::class_<tbk::KvCacheIterationStats>(m, "KvCacheIterationStats") + .def(nb::init<>()) + .def_rw("primary_max_num_blocks", &tbk::KvCacheIterationStats::primaryMaxNumBlocks) + .def_rw("primary_free_num_blocks", &tbk::KvCacheIterationStats::primaryFreeNumBlocks) + .def_rw("primary_used_num_blocks", &tbk::KvCacheIterationStats::primaryUsedNumBlocks) + 
.def_rw("secondary_max_num_blocks", &tbk::KvCacheIterationStats::secondaryMaxNumBlocks) + .def_rw("secondary_free_num_blocks", &tbk::KvCacheIterationStats::secondaryFreeNumBlocks) + .def_rw("secondary_used_num_blocks", &tbk::KvCacheIterationStats::secondaryUsedNumBlocks) + .def_rw("iter_alloc_total_blocks", &tbk::KvCacheIterationStats::iterAllocTotalBlocks) + .def_rw("iter_alloc_new_blocks", &tbk::KvCacheIterationStats::iterAllocNewBlocks) + .def_rw("iter_reused_blocks", &tbk::KvCacheIterationStats::iterReusedBlocks) + .def_rw("iter_full_reused_blocks", &tbk::KvCacheIterationStats::iterFullReusedBlocks) + .def_rw("iter_partial_reused_blocks", &tbk::KvCacheIterationStats::iterPartialReusedBlocks) + .def_rw("iter_missed_blocks", &tbk::KvCacheIterationStats::iterMissedBlocks) + .def_rw("iter_cache_hit_rate", &tbk::KvCacheIterationStats::iterCacheHitRate) + .def_rw("iter_gen_alloc_blocks", &tbk::KvCacheIterationStats::iterGenAllocBlocks) + .def_rw("iter_onboard_blocks", &tbk::KvCacheIterationStats::iterOnboardBlocks) + .def_rw("iter_onboard_bytes", &tbk::KvCacheIterationStats::iterOnboardBytes) + .def_rw("iter_offload_blocks", &tbk::KvCacheIterationStats::iterOffloadBlocks) + .def_rw("iter_offload_bytes", &tbk::KvCacheIterationStats::iterOffloadBytes) + .def_rw("iter_intra_device_copy_blocks", &tbk::KvCacheIterationStats::iterIntraDeviceCopyBlocks) + .def_rw("iter_intra_device_copy_bytes", &tbk::KvCacheIterationStats::iterIntraDeviceCopyBytes); + nb::class_(m, "TempAttentionWindowInputs") .def(nb::init<>()) .def_rw("paged_context_fmha", &tbk::TempAttentionWindowInputs::pagedContextFMHA) @@ -384,6 +412,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) .def_prop_ro("max_num_blocks", &BaseKVCacheManager::getMaxNumBlocks) .def_prop_ro("num_pools", &BaseKVCacheManager::getNumPools) .def("get_kv_cache_stats", &BaseKVCacheManager::getKvCacheStats, nb::call_guard()) + .def("get_iteration_stats", &BaseKVCacheManager::getIterationStats, 
nb::call_guard<nb::gil_scoped_release>()) .def_prop_ro("max_blocks_per_seq", [](tbk::BaseKVCacheManager& self) { return self.getOffsetTableDimensions().maxBlocksPerSeq; }) .def("get_needed_blocks_one_step", &BaseKVCacheManager::getNeededBlocksOneStep, diff --git a/scripts/visualize_kv_cache_stats.py b/scripts/visualize_kv_cache_stats.py new file mode 100644 index 00000000000..f48f67b3f8c --- /dev/null +++ b/scripts/visualize_kv_cache_stats.py @@ -0,0 +1,392 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Visualize per-iteration KV cache statistics from a JSON file. + +Produced by ``test_auto_dtype_vswa_reuse_kv_cache_stats`` (or any test that +writes the same schema). 
+ +Usage: + python scripts/visualize_kv_cache_stats.py kv_cache_stats_output/kv_cache_stats_*.json + python scripts/visualize_kv_cache_stats.py stats.json --output charts.png + python scripts/visualize_kv_cache_stats.py stats.json --per-window +""" + +import argparse +import json +import sys +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + + +# --------------------------------------------------------------------------- +# Data extraction +# --------------------------------------------------------------------------- +def load_stats(path: str) -> dict: + with open(path) as f: + return json.load(f) + + +def extract_kv_entries(payload: dict, per_window: bool = False): + """Return a list of dicts, one per iteration with kvCacheIterationStats. + + Each dict has the stat fields plus ``_iter`` (global iteration index), + ``_phase``, and ``_collectedAt``. + + When *per_window* is False the fields are aggregated (summed for + counters, averaged for rates, maxed for gauges) across window sizes. + When True a ``_windowSize`` key is added and each window size + produces its own row. 
+ """ + rows = [] + for entry in payload.get("stats", []): + kv = entry.get("kvCacheIterationStats") + if not kv: + continue + iteration = entry.get("iter", len(rows)) + phase = entry.get("_phase", "") + collected_at = entry.get("_collectedAt", 0) + + if per_window: + for ws, fields in kv.items(): + row = dict(fields) + row["_iter"] = iteration + row["_phase"] = phase + row["_collectedAt"] = collected_at + row["_windowSize"] = int(ws) + rows.append(row) + else: + # Aggregate across window sizes + agg = _aggregate_windows(kv) + agg["_iter"] = iteration + agg["_phase"] = phase + agg["_collectedAt"] = collected_at + rows.append(agg) + return rows + + +_GAUGE_FIELDS = { + "primaryMaxNumBlocks", + "primaryFreeNumBlocks", + "primaryUsedNumBlocks", + "secondaryMaxNumBlocks", + "secondaryFreeNumBlocks", + "secondaryUsedNumBlocks", +} +_RATE_FIELDS = {"iterCacheHitRate"} + + +def _aggregate_windows(kv: dict) -> dict: + """Aggregate stats across window sizes.""" + agg = {} + n = len(kv) + if n == 0: + return agg + for ws, fields in kv.items(): + for k, v in fields.items(): + if k in _GAUGE_FIELDS: + agg[k] = max(agg.get(k, 0), v) + elif k in _RATE_FIELDS: + agg[k] = agg.get(k, 0) + v + else: + agg[k] = agg.get(k, 0) + v + # Average the rate fields + for k in _RATE_FIELDS: + if k in agg: + agg[k] /= n + return agg + + +def _field_series(rows, field): + return np.array([r.get(field, 0) for r in rows], dtype=np.float64) + + +# --------------------------------------------------------------------------- +# Plotting +# --------------------------------------------------------------------------- +def plot_all(rows, title_prefix: str = "", per_window: bool = False): + """Create a multi-panel figure with KV cache diagnostics.""" + if not rows: + print("No kvCacheIterationStats entries found.", file=sys.stderr) + sys.exit(1) + + iters = np.arange(len(rows)) + + # Detect phase boundaries for vertical lines + phases = [r["_phase"] for r in rows] + phase_boundaries = [] + for i in 
range(1, len(phases)): + if phases[i] != phases[i - 1]: + phase_boundaries.append((i, phases[i])) + + fig, axes = plt.subplots(8, 1, figsize=(14, 32), sharex=True) + fig.suptitle(f"{title_prefix}KV Cache Iteration Statistics", fontsize=14, y=0.98) + + def _add_phase_markers(ax): + for idx, label in phase_boundaries: + ax.axvline(idx, color="grey", linestyle="--", linewidth=0.8, alpha=0.6) + ax.text( + idx, ax.get_ylim()[1], f" {label}", fontsize=7, va="top", ha="left", color="grey" + ) + + # --- Panel 1: GPU Pool Utilization --- + ax = axes[0] + max_blocks = _field_series(rows, "primaryMaxNumBlocks") + used_blocks = _field_series(rows, "primaryUsedNumBlocks") + free_blocks = _field_series(rows, "primaryFreeNumBlocks") + utilization = np.where(max_blocks > 0, used_blocks / max_blocks, 0) + + ax.fill_between(iters, utilization, alpha=0.3, color="tab:blue", label="GPU utilization") + ax.plot(iters, utilization, color="tab:blue", linewidth=1) + ax.set_ylabel("GPU Pool Utilization") + ax.set_ylim(0, 1.05) + ax.legend(loc="upper right", fontsize=8) + ax.set_title("Primary (GPU) Pool Utilization", fontsize=10) + _add_phase_markers(ax) + + # Annotate block counts on secondary y-axis + ax2 = ax.twinx() + ax2.plot( + iters, + used_blocks, + color="tab:orange", + linewidth=0.8, + linestyle=":", + alpha=0.7, + label="used blocks", + ) + ax2.plot( + iters, + free_blocks, + color="tab:green", + linewidth=0.8, + linestyle=":", + alpha=0.7, + label="free blocks", + ) + ax2.set_ylabel("Block Count") + ax2.legend(loc="upper left", fontsize=7) + + # --- Panel 2: Block Reuse Breakdown --- + ax = axes[1] + full_reuse = _field_series(rows, "iterFullReusedBlocks") + partial_reuse = _field_series(rows, "iterPartialReusedBlocks") + missed = _field_series(rows, "iterMissedBlocks") + + ax.bar(iters, full_reuse, label="Full Reuse", color="tab:green", alpha=0.8, width=1.0) + ax.bar( + iters, + partial_reuse, + bottom=full_reuse, + label="Partial Reuse", + color="tab:orange", + alpha=0.8, + 
width=1.0, + ) + ax.bar( + iters, + missed, + bottom=full_reuse + partial_reuse, + label="Missed", + color="tab:red", + alpha=0.8, + width=1.0, + ) + ax.set_ylabel("Blocks") + ax.legend(loc="upper right", fontsize=8) + ax.set_title("Block Reuse Breakdown (per iteration)", fontsize=10) + _add_phase_markers(ax) + + # --- Panel 3: Cache Hit Rate --- + ax = axes[2] + hit_rate = _field_series(rows, "iterCacheHitRate") + ax.plot(iters, hit_rate, color="tab:purple", linewidth=1) + ax.fill_between(iters, hit_rate, alpha=0.2, color="tab:purple") + ax.set_ylabel("Cache Hit Rate") + ax.set_ylim(0, max(1.05, hit_rate.max() * 1.1) if hit_rate.max() > 0 else 1.05) + ax.set_title("Per-Iteration Cache Hit Rate", fontsize=10) + _add_phase_markers(ax) + + # --- Panel 4: Context-Phase Allocation --- + ax = axes[3] + alloc_total = _field_series(rows, "iterAllocTotalBlocks") + ax.plot(iters, alloc_total, label="AllocTotal", color="tab:blue", linewidth=1) + ax.fill_between(iters, alloc_total, alpha=0.2, color="tab:blue") + ax.set_ylabel("Blocks") + ax.legend(loc="upper right", fontsize=8) + ax.set_title("Context-Phase Allocation (per iteration)", fontsize=10) + _add_phase_markers(ax) + + # --- Panel 5: Generation-Phase Allocation --- + ax = axes[4] + alloc_new = _field_series(rows, "iterAllocNewBlocks") + gen_alloc = _field_series(rows, "iterGenAllocBlocks") + ax.plot(iters, gen_alloc, label="GenAlloc", color="tab:brown", linewidth=1) + ax.plot(iters, alloc_new, label="AllocNew", color="tab:cyan", linewidth=1) + ax.set_ylabel("Blocks") + ax.legend(loc="upper right", fontsize=8) + ax.set_title("Generation-Phase Allocation (per iteration)", fontsize=10) + _add_phase_markers(ax) + + # --- Panel 6: Onboard Traffic (Host → GPU) --- + ax = axes[5] + onboard_bytes = _field_series(rows, "iterOnboardBytes") + onboard_mib = onboard_bytes / (1024 * 1024) + ax.plot(iters, onboard_mib, label="Onboard (Host→GPU)", color="tab:blue", linewidth=1) + ax.fill_between(iters, onboard_mib, alpha=0.2, 
color="tab:blue") + ax.set_ylabel("MiB") + ax.legend(loc="upper right", fontsize=8) + ax.set_title("Onboard Traffic — Host → GPU (per iteration)", fontsize=10) + _add_phase_markers(ax) + if onboard_bytes.sum() == 0: + ax.text( + 0.5, + 0.5, + "No onboard transfers (secondary pool inactive)", + transform=ax.transAxes, + ha="center", + va="center", + fontsize=10, + color="grey", + style="italic", + ) + + # --- Panel 7: Offload Traffic (GPU → Host) --- + ax = axes[6] + offload_bytes = _field_series(rows, "iterOffloadBytes") + offload_mib = offload_bytes / (1024 * 1024) + ax.plot(iters, offload_mib, label="Offload (GPU→Host)", color="tab:red", linewidth=1) + ax.fill_between(iters, offload_mib, alpha=0.2, color="tab:red") + ax.set_ylabel("MiB") + ax.set_xlabel("Iteration Index") + ax.legend(loc="upper right", fontsize=8) + ax.set_title("Offload Traffic — GPU → Host (per iteration)", fontsize=10) + _add_phase_markers(ax) + if offload_bytes.sum() == 0: + ax.text( + 0.5, + 0.5, + "No offload transfers (secondary pool inactive)", + transform=ax.transAxes, + ha="center", + va="center", + fontsize=10, + color="grey", + style="italic", + ) + + # --- Panel 8: Intra-Device Copy Traffic (GPU → GPU) --- + ax = axes[7] + intra_copy_bytes = _field_series(rows, "iterIntraDeviceCopyBytes") + intra_copy_mib = intra_copy_bytes / (1024 * 1024) + ax.plot( + iters, intra_copy_mib, label="Intra-Device Copy (GPU→GPU)", color="tab:olive", linewidth=1 + ) + ax.fill_between(iters, intra_copy_mib, alpha=0.2, color="tab:olive") + ax.set_ylabel("MiB") + ax.set_xlabel("Iteration Index") + ax.legend(loc="upper right", fontsize=8) + ax.set_title("Intra-Device Copy — GPU → GPU (per iteration)", fontsize=10) + _add_phase_markers(ax) + if intra_copy_bytes.sum() == 0: + ax.text( + 0.5, + 0.5, + "No intra-device copies", + transform=ax.transAxes, + ha="center", + va="center", + fontsize=10, + color="grey", + style="italic", + ) + + plt.tight_layout(rect=[0, 0, 1, 0.97]) + return fig + + +def 
plot_per_window(rows, title_prefix: str = ""): + """One figure per window size with the same 5-panel layout.""" + window_sizes = sorted({r["_windowSize"] for r in rows}) + figs = {} + for ws in window_sizes: + ws_rows = [r for r in rows if r["_windowSize"] == ws] + fig = plot_all(ws_rows, title_prefix=f"{title_prefix}[window={ws}] ") + figs[ws] = fig + return figs + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- +def main(): + parser = argparse.ArgumentParser( + description="Visualize KV cache iteration statistics from JSON" + ) + parser.add_argument("json_file", help="Path to the stats JSON file") + parser.add_argument("--output", "-o", help="Output image path (default: display interactively)") + parser.add_argument( + "--per-window", action="store_true", help="Generate separate charts per window size" + ) + parser.add_argument("--dpi", type=int, default=150, help="DPI for saved images (default: 150)") + args = parser.parse_args() + + payload = load_stats(args.json_file) + model = payload.get("model", "unknown") + num_entries = payload.get("num_entries", "?") + print(f"Model: {model}") + print(f"Total stats entries: {num_entries}") + + if args.per_window: + rows = extract_kv_entries(payload, per_window=True) + window_sizes = sorted({r["_windowSize"] for r in rows}) + print(f"Window sizes: {window_sizes}") + kv_count = len(rows) + print(f"Entries with kvCacheIterationStats: {kv_count}") + if kv_count == 0: + print("No kvCacheIterationStats found in any entry.", file=sys.stderr) + sys.exit(1) + + figs = plot_per_window(rows, title_prefix=f"{model} — ") + if args.output: + out = Path(args.output) + for ws, fig in figs.items(): + p = out.with_stem(f"{out.stem}_window{ws}") + fig.savefig(p, dpi=args.dpi, bbox_inches="tight") + print(f"Saved: {p}") + plt.close("all") + else: + plt.show() + else: + rows = extract_kv_entries(payload, per_window=False) + 
kv_count = len(rows) + print(f"Entries with kvCacheIterationStats: {kv_count}") + if kv_count == 0: + print("No kvCacheIterationStats found in any entry.", file=sys.stderr) + sys.exit(1) + + fig = plot_all(rows, title_prefix=f"{model} — ") + if args.output: + fig.savefig(args.output, dpi=args.dpi, bbox_inches="tight") + print(f"Saved: {args.output}") + plt.close(fig) + else: + plt.show() + + +if __name__ == "__main__": + main() diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 1e51fa55fa9..34012dde254 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -516,6 +516,11 @@ def on_detected(): self.stats_lock = threading.Lock() self.stats = [] + self._latest_kv_iter_stats = None + self._last_kv_iter_stats_fetch_iter = None + self._kv_iter_stats_interval = getattr( + getattr(self.llm_args, 'kv_cache_config', None), + 'iteration_stats_interval', 1) self.gather_all_responses = False self.kv_cache_transceiver = kv_cache_transceiver @@ -1094,6 +1099,17 @@ def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests, kv_stats_to_save.cache_hit_rate = kv_stats.cache_hit_rate stats.kv_cache_stats = kv_stats_to_save + # Collect per-iteration stats (with deltas) at configured interval. + # Between calls, C++ deltas accumulate so the reported values cover multiple iterations. + # Guard: only fetch once per iter_counter to avoid draining deltas in PP multi-batch. 
+ if (self.iter_counter % self._kv_iter_stats_interval == 0 and + self._last_kv_iter_stats_fetch_iter != self.iter_counter): + self._latest_kv_iter_stats = kv_cache_manager.get_iteration_stats( + ) + self._last_kv_iter_stats_fetch_iter = self.iter_counter + else: + self._latest_kv_iter_stats = None + stats.inflight_batching_stats.num_context_requests = scheduled_batch.num_context_requests stats.inflight_batching_stats.num_gen_requests = scheduled_batch.num_generation_requests stats.inflight_batching_stats.num_scheduled_requests = stats.inflight_batching_stats.num_context_requests + stats.inflight_batching_stats.num_gen_requests @@ -1161,7 +1177,7 @@ def _append_iter_stats(self, with self.stats_lock: if len(self.stats) > self.max_stats_len: self.stats.pop(0) - self.stats.append((stats, req_stats)) + self.stats.append((stats, req_stats, self._latest_kv_iter_stats)) def _process_iter_stats( self, diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 971ff8c5402..3c0fa9e6011 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -1192,6 +1192,10 @@ def snapshot_warmup_baseline(self): self._warmup_reused_blocks = raw.reused_blocks self._warmup_missed_blocks = raw.missed_blocks + def get_iteration_stats(self): + """Get per-iteration KV cache stats keyed by window size. 
Resets deltas on each call.""" + return self.impl.get_iteration_stats() + def rewind_kv_cache(self, request: LlmRequest, rewind_len: int): self.impl.rewind_kv_cache(request.py_request_id, rewind_len) @@ -2208,6 +2212,10 @@ def get_kv_cache_stats(self): return kv_cache_stats + def get_iteration_stats(self): + """V2 does not support per-iteration stats yet.""" + return None + def get_block_ids_per_seq(self, request_ids: List[int]) -> torch.Tensor: block_ids_per_seq = self.get_batch_cache_indices(request_ids) block_ids_per_seq_tensors = [ diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index d2ac243a027..6e0b5b2b931 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -304,7 +304,7 @@ def fetch_stats(self) -> list: iter_stats = self.engine.get_latest_iteration_stats() #TODO: Support req stats with TRT engine # This would require ensuring iter and req stats have same size - return [(iter_stat, None) for iter_stat in iter_stats] + return [(iter_stat, None, None) for iter_stat in iter_stats] else: return self.engine.get_latest_iteration_stats() @@ -669,9 +669,9 @@ def get_disaggregated_params(self) -> dict: # Define a Callable to join iteration and request stats @staticmethod - def _stats_serializer( - stats: Tuple[tllm.IterationStats, tllm.RequestStats]) -> str: - iteration_stats, req_stats = stats + def _stats_serializer(stats) -> str: + iteration_stats, req_stats = stats[0], stats[1] + kv_iter_stats = stats[2] if len(stats) > 2 else None stats_dict = json.loads(iteration_stats.to_json_str()) if req_stats is not None and len(req_stats) > 0: @@ -680,6 +680,35 @@ def _stats_serializer( stats_dict["requestStats"].append( json.loads(req_stat.to_json_str())) + # Inject per-iteration KV cache stats (keyed by window size) + if kv_iter_stats is not None: + stats_dict["kvCacheIterationStats"] = { + str(window_size): { + "primaryMaxNumBlocks": s.primary_max_num_blocks, + "primaryFreeNumBlocks": 
s.primary_free_num_blocks, + "primaryUsedNumBlocks": s.primary_used_num_blocks, + "secondaryMaxNumBlocks": s.secondary_max_num_blocks, + "secondaryFreeNumBlocks": s.secondary_free_num_blocks, + "secondaryUsedNumBlocks": s.secondary_used_num_blocks, + "iterAllocTotalBlocks": s.iter_alloc_total_blocks, + "iterAllocNewBlocks": s.iter_alloc_new_blocks, + "iterReusedBlocks": s.iter_reused_blocks, + "iterFullReusedBlocks": s.iter_full_reused_blocks, + "iterPartialReusedBlocks": s.iter_partial_reused_blocks, + "iterMissedBlocks": s.iter_missed_blocks, + "iterCacheHitRate": s.iter_cache_hit_rate, + "iterGenAllocBlocks": s.iter_gen_alloc_blocks, + "iterOnboardBlocks": s.iter_onboard_blocks, + "iterOnboardBytes": s.iter_onboard_bytes, + "iterOffloadBlocks": s.iter_offload_blocks, + "iterOffloadBytes": s.iter_offload_bytes, + "iterIntraDeviceCopyBlocks": + s.iter_intra_device_copy_blocks, + "iterIntraDeviceCopyBytes": s.iter_intra_device_copy_bytes, + } + for window_size, s in kv_iter_stats.items() + } + # Convert back to JSON string return json.dumps(stats_dict) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 30cccbc88f9..cac912bb3e4 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -2248,6 +2248,15 @@ class KvCacheConfig(StrictBaseModel, PybindMirror): "The maximum size in bytes of GPU memory that can be allocated for the KV cache. If both `max_gpu_total_bytes` and `free_gpu_memory_fraction` are specified, memory corresponding to the minimum will be allocated." ) + # This is a pure python field, not a pybind field. It is only for the Pytorch backend. + iteration_stats_interval: PositiveInt = Field( + default=1, + description= + "How often (in iterations) to collect per-iteration KV cache statistics. " + "A value of 1 means every iteration; a value of N means every Nth iteration. " + "Between collections, the C++ deltas accumulate, so the reported deltas cover N iterations." 
+ ) + # This is a pure python field, not a pybind field. It is only for the Pytorch backend. dtype: str = Field( default="auto", diff --git a/tensorrt_llm/metrics/collector.py b/tensorrt_llm/metrics/collector.py index 3f4a1dc8a68..2215fc73454 100644 --- a/tensorrt_llm/metrics/collector.py +++ b/tensorrt_llm/metrics/collector.py @@ -42,6 +42,16 @@ class MetricsCollector: trtllm_kv_cache_reused_blocks_total trtllm_kv_cache_missed_blocks_total trtllm_kv_cache_utilization + trtllm_kv_cache_host_utilization + trtllm_kv_cache_iter_reuse_rate + trtllm_kv_cache_iter_reused_blocks_total + trtllm_kv_cache_iter_full_reused_blocks_total + trtllm_kv_cache_iter_partial_reused_blocks_total + trtllm_kv_cache_iter_missed_blocks_total + trtllm_kv_cache_gen_alloc_blocks_total + trtllm_kv_cache_onboard_bytes_total + trtllm_kv_cache_offload_bytes_total + trtllm_kv_cache_intra_device_copy_bytes_total """ labelname_finish_reason = "finished_reason" @@ -123,6 +133,51 @@ def __init__(self, labels: Dict[str, str]) -> None: self._prev_reused_blocks = 0 self._prev_missed_blocks = 0 + # Per-iteration KV cache gauges + self.kv_cache_host_utilization = Gauge( + name=self.metric_prefix + "kv_cache_host_utilization", + documentation="KV cache host (secondary) pool utilization", + labelnames=self.labels.keys()) + self.kv_cache_iter_reuse_rate = Gauge( + name=self.metric_prefix + "kv_cache_iter_reuse_rate", + documentation="Per-iteration KV cache block reuse rate", + labelnames=self.labels.keys()) + + # Per-iteration KV cache counters (monotonically increasing via accumulated deltas) + self.kv_cache_iter_reused_blocks = Counter( + name=self.metric_prefix + "kv_cache_iter_reused_blocks", + documentation="Total reused KV cache blocks (full + partial)", + labelnames=self.labels.keys()) + self.kv_cache_iter_full_reused_blocks = Counter( + name=self.metric_prefix + "kv_cache_iter_full_reused_blocks", + documentation="Total fully reused KV cache blocks", + labelnames=self.labels.keys()) + 
self.kv_cache_iter_partial_reused_blocks = Counter( + name=self.metric_prefix + "kv_cache_iter_partial_reused_blocks", + documentation="Total partially reused KV cache blocks", + labelnames=self.labels.keys()) + self.kv_cache_iter_missed_blocks = Counter( + name=self.metric_prefix + "kv_cache_iter_missed_blocks", + documentation="Total missed KV cache blocks (context phase)", + labelnames=self.labels.keys()) + self.kv_cache_gen_alloc_blocks_total = Counter( + name=self.metric_prefix + "kv_cache_gen_alloc_blocks_total", + documentation="Total blocks allocated during generation phase", + labelnames=self.labels.keys()) + self.kv_cache_onboard_bytes_total = Counter( + name=self.metric_prefix + "kv_cache_onboard_bytes_total", + documentation="Total bytes transferred from host to GPU (onboard)", + labelnames=self.labels.keys()) + self.kv_cache_offload_bytes_total = Counter( + name=self.metric_prefix + "kv_cache_offload_bytes_total", + documentation="Total bytes transferred from GPU to host (offload)", + labelnames=self.labels.keys()) + self.kv_cache_intra_device_copy_bytes_total = Counter( + name=self.metric_prefix + "kv_cache_intra_device_copy_bytes_total", + documentation= + "Total bytes copied within GPU (intra-device block copies)", + labelnames=self.labels.keys()) + def _label_merge(self, labels: Dict[str, str]) -> Dict[str, str]: if labels is None or len(labels) == 0: return self.labels @@ -240,3 +295,67 @@ def log_iteration_stats(self, iteration_stats: dict) -> None: if max_num_blocks: utilization = kv_stats["usedNumBlocks"] / max_num_blocks self._log_gauge(self.kv_cache_utilization, utilization) + + # Per-iteration KV cache stats (aggregated across window sizes) + if kv_iter := iteration_stats.get("kvCacheIterationStats"): + # Aggregate across all window sizes + total_secondary_max = 0 + total_secondary_used = 0 + total_reused = 0 + total_full_reused = 0 + total_partial_reused = 0 + total_missed = 0 + total_gen_alloc = 0 + total_onboard_bytes = 0 + 
total_offload_bytes = 0 + total_intra_device_copy_bytes = 0 + + for ws_stats in kv_iter.values(): + total_secondary_max += ws_stats.get("secondaryMaxNumBlocks", 0) + total_secondary_used += ws_stats.get("secondaryUsedNumBlocks", + 0) + total_reused += ws_stats.get("iterReusedBlocks", 0) + total_full_reused += ws_stats.get("iterFullReusedBlocks", 0) + total_partial_reused += ws_stats.get("iterPartialReusedBlocks", + 0) + total_missed += ws_stats.get("iterMissedBlocks", 0) + total_gen_alloc += ws_stats.get("iterGenAllocBlocks", 0) + total_onboard_bytes += ws_stats.get("iterOnboardBytes", 0) + total_offload_bytes += ws_stats.get("iterOffloadBytes", 0) + total_intra_device_copy_bytes += ws_stats.get( + "iterIntraDeviceCopyBytes", 0) + + # Gauges + if total_secondary_max > 0: + self._log_gauge(self.kv_cache_host_utilization, + total_secondary_used / total_secondary_max) + iter_total = total_reused + total_missed + if iter_total > 0: + self._log_gauge(self.kv_cache_iter_reuse_rate, + total_reused / iter_total) + + # Counters (increment by delta) + if total_reused > 0: + self._log_counter(self.kv_cache_iter_reused_blocks, {}, + total_reused) + if total_full_reused > 0: + self._log_counter(self.kv_cache_iter_full_reused_blocks, {}, + total_full_reused) + if total_partial_reused > 0: + self._log_counter(self.kv_cache_iter_partial_reused_blocks, {}, + total_partial_reused) + if total_missed > 0: + self._log_counter(self.kv_cache_iter_missed_blocks, {}, + total_missed) + if total_gen_alloc > 0: + self._log_counter(self.kv_cache_gen_alloc_blocks_total, {}, + total_gen_alloc) + if total_onboard_bytes > 0: + self._log_counter(self.kv_cache_onboard_bytes_total, {}, + total_onboard_bytes) + if total_offload_bytes > 0: + self._log_counter(self.kv_cache_offload_bytes_total, {}, + total_offload_bytes) + if total_intra_device_copy_bytes > 0: + self._log_counter(self.kv_cache_intra_device_copy_bytes_total, + {}, total_intra_device_copy_bytes) diff --git 
a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index bebf897bf3d..1e1b6a401f5 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -832,10 +832,13 @@ async def _iteration_stats_collector_loop(self): Background task that continuously collects iteration statistics from the LLM engine. This task runs in the background for the lifetime of the server and drains iteration - stats from the engine's stats queue, logging only the latest stats to Prometheus. - Since iteration stats are gauges (point-in-time metrics like KV cache hit rate), - only the most recent value is needed. This approach avoids blocking request completion - while collecting stats and minimizes redundant metric updates. + stats from the engine's stats queue, logging every stat to Prometheus. Gauges + (kv_cache_hit_rate, kv_cache_utilization, kv_cache_iter_reuse_rate) are naturally + overwritten with the latest value, while counters (missed_blocks_total, + gen_alloc_blocks_total, etc.) must be incremented by *every* per-iteration delta + to remain accurate. Logging only the latest stat would drop counter deltas from + earlier iterations and could leave gauges unset if the latest iteration had no + context-phase activity. The task sleeps when idle and is woken up via _iteration_stats_wakeup_event when requests complete. @@ -851,17 +854,11 @@ async def _iteration_stats_collector_loop(self): # Clear the event for next wakeup self._iteration_stats_wakeup_event.clear() - # Drain all available iteration stats from the queue, but only log the latest - # Since metrics are gauges (point-in-time values), only the most recent stat matters + # Drain all available iteration stats and log each one to Prometheus. 
try: - latest_stat = None async for llm_stat in self.generator.get_stats_async( timeout=0.5): - latest_stat = llm_stat # Keep only the latest - - # Log only the most recent iteration stats to Prometheus - if latest_stat is not None: - self.metrics_collector.log_iteration_stats(latest_stat) + self.metrics_collector.log_iteration_stats(llm_stat) except Exception as e: # Log errors but continue collecting stats logger.error(f"Error collecting iteration stats: {e}", diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 39feb299531..1d856ab0ddc 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -1483,6 +1483,78 @@ def test_auto_dtype_vswa_reuse_low_memory_available_partial_reuse(self): task = MMLU(self.MODEL_NAME) task.evaluate(llm) + def test_auto_dtype_vswa_reuse_kv_cache_stats(self): + """Mirror of test_auto_dtype_vswa_reuse that collects per-iteration stats. + + Collects per-iteration KV cache statistics and writes them to a JSON + file for offline visualization with + ``scripts/visualize_kv_cache_stats.py``. + """ + import json + import time + from pathlib import Path + + kv_cache_config = KvCacheConfig( + enable_block_reuse=True, + max_attention_window=[512, 512, 512, 512, 512, 32768], + iteration_stats_interval=1, + ) + + all_stats = [] + + def drain_stats(llm, phase_label): + """Drain the stats queue and tag each entry. + + Tags each entry with wall-clock time and a human-readable phase + label. 
+ """ + stats = llm.get_stats(timeout=2) + ts = time.time() + for entry in stats: + entry["_collectedAt"] = ts + entry["_phase"] = phase_label + all_stats.extend(stats) + + with LLM( + self.MODEL_PATH, + kv_cache_config=kv_cache_config, + enable_iter_perf_stats=True, + ) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + drain_stats(llm, "GSM8K") + + task = MMLU(self.MODEL_NAME) + task.evaluate(llm) + drain_stats(llm, "MMLU") + + # Write collected stats to JSON + out_dir = Path( + os.environ.get( + "KV_CACHE_STATS_OUTPUT_DIR", + "kv_cache_stats_output", + )) + out_dir.mkdir(parents=True, exist_ok=True) + timestamp = time.strftime("%Y%m%d_%H%M%S") + out_path = out_dir / f"kv_cache_stats_{timestamp}.json" + + payload = { + "model": self.MODEL_NAME, + "kv_cache_config": { + "enable_block_reuse": + kv_cache_config.enable_block_reuse, + "max_attention_window": + kv_cache_config.max_attention_window, + "iteration_stats_interval": + kv_cache_config.iteration_stats_interval, + }, + "num_entries": len(all_stats), + "stats": all_stats, + } + out_path.write_text(json.dumps(payload, indent=2)) + print(f"\n[kv_cache_stats] Wrote {len(all_stats)} entries to " + f"{out_path}") + def test_auto_dtype_vswa_chunked_prefill_without_reuse(self): # NOTE: Test with VSWA kv cache config. kv_cache_config = KvCacheConfig( @@ -2728,13 +2800,13 @@ def test_nvfp4_multi_gpus_chunked_prefill(self, tp_size, pp_size, ep_size, @skip_pre_blackwell @pytest.mark.skip_less_device(8) def test_nvfp4_multi_gpus_corner_case(self): - """ - This test is used to test the corner case of the NVFP4 model. - When using the same value for max_seq_len and max_num_tokens, there will be no - enough kv block for the dummy requests in CUDA graph warmup when creating - the py_executor before estimating kv cache. Then CUDA graph capture will be - triggered when estimating kv cache. This may cause some errors. - More info in https://nvbugs/5485325. + """Test the corner case of the NVFP4 model. 
+ + When using the same value for max_seq_len and max_num_tokens, there will + be no enough kv block for the dummy requests in CUDA graph warmup when + creating the py_executor before estimating kv cache. Then CUDA graph + capture will be triggered when estimating kv cache. This may cause some + errors. More info in https://nvbugs/5485325. """ kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.80, dtype="fp8", @@ -3547,8 +3619,8 @@ def test_nvfp4(self, tp_size): "ignore:.*configuration is not supported by the fused routing kernel.*:UserWarning" ) def test_nvfp4_longseq_trtllm_moe_stress(self, mocker): - """ - Long-sequence MoE stress test with PDL enabled. + """Long-sequence MoE stress test with PDL enabled. + RCCA: https://nvbugspro.nvidia.com/bug/5661741 """ patch_mpi_pool_session_for_env(mocker, {"TRTLLM_ENABLE_PDL": "1"}) @@ -3629,8 +3701,8 @@ def test_nvfp4_longseq_trtllm_moe_stress(self, mocker): "ignore:.*configuration is not supported by the fused routing kernel.*:UserWarning" ) def test_nvfp4_longseq_trtllm_moe_async_cancel(self, mocker): - """ - Long-sequence MoE async streaming test with cancellation. + """Long-sequence MoE async streaming test with cancellation. + RCCA: https://nvbugspro.nvidia.com/bug/5661741 """ patch_mpi_pool_session_for_env(mocker, {"TRTLLM_ENABLE_PDL": "1"}) diff --git a/tests/integration/defs/kv_cache/conftest.py b/tests/integration/defs/kv_cache/conftest.py new file mode 100644 index 00000000000..d580b54f30c --- /dev/null +++ b/tests/integration/defs/kv_cache/conftest.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Used by kv_cache/test_kv_cache_iteration_stats.py for --verbose-stats option. + + +def pytest_addoption(parser): + parser.addoption( + "--verbose-stats", + action="store_true", + default=False, + help="Dump all 18 KV cache stat fields for every stats entry", + ) diff --git a/tests/integration/defs/kv_cache/test_kv_cache_iteration_stats.py b/tests/integration/defs/kv_cache/test_kv_cache_iteration_stats.py new file mode 100644 index 00000000000..6319aa85976 --- /dev/null +++ b/tests/integration/defs/kv_cache/test_kv_cache_iteration_stats.py @@ -0,0 +1,505 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Integration tests for per-iteration KV cache statistics (kvCacheIterationStats). + +Tests verify that the 18 stat fields are correctly populated across +different inference scenarios: cold start, block reuse (partial/full), +shared prefix, batch generation, long context, and rapid-fire. 
+ +Usage: + # Via pytest (recommended): + pytest tests/integration/defs/kv_cache/test_kv_cache_iteration_stats.py + pytest tests/integration/defs/kv_cache/test_kv_cache_iteration_stats.py -k "cold_start" + pytest tests/integration/defs/kv_cache/test_kv_cache_iteration_stats.py -s # show prints + pytest tests/integration/defs/kv_cache/test_kv_cache_iteration_stats.py -s --verbose-stats + + # Standalone (still supported): + python3 tests/integration/defs/kv_cache/test_kv_cache_iteration_stats.py + python3 tests/integration/defs/kv_cache/test_kv_cache_iteration_stats.py --verbose + python3 tests/integration/defs/kv_cache/test_kv_cache_iteration_stats.py --test 2 3 + python3 tests/integration/defs/kv_cache/test_kv_cache_iteration_stats.py --list +""" + +import argparse + +import pytest + +from tensorrt_llm import LLM +from tensorrt_llm.llmapi import KvCacheConfig +from tensorrt_llm.sampling_params import SamplingParams + +from ..conftest import llm_models_root + +MODEL = f"{llm_models_root()}/llama-models-v2/TinyLlama-1.1B-Chat-v1.0" + +ALL_FIELDS = [ + # Instantaneous gauges — primary (GPU) pool + "primaryMaxNumBlocks", + "primaryFreeNumBlocks", + "primaryUsedNumBlocks", + # Instantaneous gauges — secondary (host) pool + "secondaryMaxNumBlocks", + "secondaryFreeNumBlocks", + "secondaryUsedNumBlocks", + # Per-iteration deltas — context phase + "iterAllocTotalBlocks", + "iterAllocNewBlocks", + "iterReusedBlocks", + "iterFullReusedBlocks", + "iterPartialReusedBlocks", + "iterMissedBlocks", + "iterCacheHitRate", + # Per-iteration deltas — generation phase + "iterGenAllocBlocks", + # Per-iteration deltas — transfer traffic + "iterOnboardBlocks", + "iterOnboardBytes", + "iterOffloadBlocks", + "iterOffloadBytes", + # Intra-device (GPU → GPU) block copies + "iterIntraDeviceCopyBlocks", + "iterIntraDeviceCopyBytes", +] + +TEST_NAMES = { + 1: "Cold start", + 2: "Partial block reuse", + 3: "Full block reuse", + 4: "Shared prefix", + 5: "Batch generation", + 6: "Long context", + 
7: "Rapid-fire", + 8: "Field completeness", +} + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +def _is_verbose(request): + """Check if verbose stats output is requested (pytest or standalone).""" + if request is not None: + return request.config.getoption("--verbose-stats", default=False) + return False + + +def print_kv_stats(label, stats_list): + """Print all 18 fields for every stats entry.""" + print(f"\n{'=' * 60}") + print(f" {label}: {len(stats_list)} stats entries") + print(f"{'=' * 60}") + found = False + for i, s in enumerate(stats_list): + ki = s.get("kvCacheIterationStats") + if ki: + found = True + for ws, v in ki.items(): + print(f"\n --- entry[{i}] window_size={ws} ---") + for field in ALL_FIELDS: + val = v.get(field, "") + print(f" {field:30s} = {val}") + else: + keys = list(s.keys())[:8] + print(f" entry[{i}]: no kvCacheIterationStats (keys: {keys})") + if not found: + print(" WARNING: no entry contained kvCacheIterationStats!") + + +def collect_stats(llm, all_collected): + """Get stats and append to the cumulative list.""" + stats = llm.get_stats(timeout=2) + all_collected.extend(stats) + return stats + + +def find_kv_entries(stats_list): + """Extract all (entry_index, window_size, fields_dict) from stats.""" + results = [] + for i, s in enumerate(stats_list): + ki = s.get("kvCacheIterationStats") + if ki: + for ws, v in ki.items(): + results.append((i, ws, v)) + return results + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- +@pytest.fixture(scope="module") +def llm_instance(): + """Create a shared LLM instance for all tests in this module.""" + llm = LLM( + model=MODEL, + kv_cache_config=KvCacheConfig(enable_block_reuse=True, iteration_stats_interval=1), + enable_iter_perf_stats=True, + 
return_perf_metrics=True, + ) + yield llm + llm.shutdown() + + +@pytest.fixture(scope="module") +def all_collected(): + """Shared list to accumulate stats across tests.""" + return [] + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- +@pytest.mark.threadleak(enabled=False) +class TestKvCacheIterationStats: + def test_cold_start(self, llm_instance, all_collected, request): + """Cold start — 3 unique prompts, no reuse expected.""" + prompts = [ + "The history of ancient Rome begins with", + "Photosynthesis is the process by which plants", + "The Pythagorean theorem states that in a right triangle", + ] + # Collect stats between each generate() call so that the context-phase + # iteration stats (where iterMissedBlocks > 0) are captured before being + # diluted by many generation/idle iterations with zero deltas. + all_kv = [] + for p in prompts: + o = llm_instance.generate([p], SamplingParams(max_tokens=64)) + print(f" Input: {p[:50]}...") + print(f" Output: {o[0].outputs[0].text!r}") + stats = collect_stats(llm_instance, all_collected) + if _is_verbose(request): + print_kv_stats(f"Cold start ({p[:20]}...)", stats) + all_kv.extend(find_kv_entries(stats)) + + assert len(all_kv) > 0, "kvCacheIterationStats not found" + + # The overlap scheduler processes stats for the *previous* batch, + # so per-iteration deltas may not align with the iteration that + # allocated context blocks. Check the cumulative kvCacheStats first + # (always reliable), then the per-iteration deltas. 
+ cumul_missed = max( + (s.get("kvCacheStats", {}).get("missedBlocks", 0) for s in all_collected), + default=0, + ) + iter_missed_found = any(v["iterMissedBlocks"] > 0 for _, _, v in all_kv) + assert cumul_missed > 0 or iter_missed_found, ( + "No cold misses detected: cumulative missedBlocks = " + f"{cumul_missed}, iterMissedBlocks > 0 in {sum(1 for _, _, v in all_kv if v['iterMissedBlocks'] > 0)}" + f"/{len(all_kv)} entries" + ) + assert any(v["iterAllocTotalBlocks"] > 0 for _, _, v in all_kv), ( + "iterAllocTotalBlocks = 0 in all entries" + ) + assert any(v["iterGenAllocBlocks"] > 0 for _, _, v in all_kv), ( + "iterGenAllocBlocks = 0 in all entries" + ) + + def test_partial_block_reuse(self, llm_instance, all_collected, request): + """Partial block reuse — short prompt (< 1 block) repeated x3. + + With tokens_per_block=32, a short prompt fits within a single block + without filling it completely. On repeat, the block is reused but + classified as partial (iterPartialReusedBlocks). + """ + repeated = "The theory of general relativity tells us that gravity is" + print(f" Input: {repeated!r}") + for i in range(3): + o = llm_instance.generate([repeated], SamplingParams(max_tokens=64)) + print(f" Output (repeat {i + 1}): {o[0].outputs[0].text!r}") + stats = collect_stats(llm_instance, all_collected) + if _is_verbose(request): + print_kv_stats("Partial block reuse", stats) + kv = find_kv_entries(stats) + + assert len(kv) > 0, "kvCacheIterationStats not found" + assert any(v["iterPartialReusedBlocks"] > 0 for _, _, v in kv), ( + "iterPartialReusedBlocks = 0 in all entries (expected partial reuse)" + ) + assert any(v["iterCacheHitRate"] > 0 for _, _, v in kv), ( + "iterCacheHitRate = 0 in all entries (expected cache hits)" + ) + + def test_full_block_reuse(self, llm_instance, all_collected, request): + """Full block reuse — prompt spanning 3+ blocks, repeated. + + With tokens_per_block=32, a prompt of ~120 tokens spans ~4 blocks. 
+ The first N-1 fully-filled blocks should register as iterFullReusedBlocks + on the second request. + """ + long_prompt = ( + "The quick brown fox jumps over the lazy dog and then runs across " + "the wide open field where the tall green grass sways gently in the " + "warm summer breeze while birds sing melodiously in the trees above " + "and the river flows calmly through the valley carrying leaves and " + "small stones downstream toward the distant ocean where waves crash " + "against the rocky shore and seagulls circle overhead looking for " + "fish beneath the sparkling surface of the deep blue water that" + ) + + # First request — cold, populates the radix tree + print(f" Input: {long_prompt[:80]!r}... (~{len(long_prompt.split())} words)") + o1 = llm_instance.generate([long_prompt], SamplingParams(max_tokens=16)) + print(f" Output (1st, cold): {o1[0].outputs[0].text!r}") + collect_stats(llm_instance, all_collected) # drain stats from first request + + # Second request — identical prompt, should reuse full blocks + o2 = llm_instance.generate([long_prompt], SamplingParams(max_tokens=16)) + print(f" Output (2nd, warm): {o2[0].outputs[0].text!r}") + stats = collect_stats(llm_instance, all_collected) + if _is_verbose(request): + print_kv_stats("Full block reuse", stats) + kv = find_kv_entries(stats) + + assert len(kv) > 0, "kvCacheIterationStats not found" + max_full = max(v["iterFullReusedBlocks"] for _, _, v in kv) + max_reused = max(v["iterReusedBlocks"] for _, _, v in kv) + assert max_full > 0, "iterFullReusedBlocks = 0 in all entries (expected full block reuse)" + assert max_reused > max_full, ( + f"iterReusedBlocks = {max_reused} == iterFullReusedBlocks = {max_full} " + "(expected at least one partial block too)" + ) + + def test_shared_prefix(self, llm_instance, all_collected, request): + """Shared prefix — common prefix, 5 different suffixes.""" + prefix = "In the field of machine learning, neural networks are commonly used for " + suffixes = [ + "image 
classification where the input data is", + "natural language processing where the model learns to", + "training with backpropagation which involves computing", + "building layers of neurons that can represent", + "learning complex patterns in data such as", + ] + print(f" Prefix: {prefix!r}") + for s in suffixes: + full = prefix + s + o = llm_instance.generate([full], SamplingParams(max_tokens=64)) + print(f" Input: ...{s!r}") + print(f" Output: {o[0].outputs[0].text!r}") + stats = collect_stats(llm_instance, all_collected) + if _is_verbose(request): + print_kv_stats("Shared prefix", stats) + kv = find_kv_entries(stats) + + assert len(kv) > 0, "kvCacheIterationStats not found" + assert any(v["iterReusedBlocks"] > 0 for _, _, v in kv), ( + "iterReusedBlocks = 0 in all entries (expected prefix reuse)" + ) + + def test_batch_generation(self, llm_instance, all_collected, request): + """Batch generation — 4 prompts in one generate() call.""" + batch = [ + "The capital of France is known for its", + "The capital of Germany is a city that", + "The capital of Japan is famous for its", + "The capital of Brazil was designed by", + ] + outputs = llm_instance.generate(batch, SamplingParams(max_tokens=64)) + for p, o in zip(batch, outputs): + print(f" Input: {p!r}") + print(f" Output: {o.outputs[0].text!r}") + stats = collect_stats(llm_instance, all_collected) + if _is_verbose(request): + print_kv_stats("Batch generation", stats) + kv = find_kv_entries(stats) + + assert len(kv) > 0, "kvCacheIterationStats not found" + assert any(v["iterAllocTotalBlocks"] > 0 for _, _, v in kv), ( + "iterAllocTotalBlocks = 0 in all entries" + ) + assert any(v["iterGenAllocBlocks"] > 0 for _, _, v in kv), ( + "iterGenAllocBlocks = 0 in all entries" + ) + assert any(v["primaryUsedNumBlocks"] > 0 for _, _, v in kv), ( + "primaryUsedNumBlocks = 0 in all entries" + ) + + def test_long_context(self, llm_instance, all_collected, request): + """Long context — single long prompt to allocate many 
blocks.""" + long_prompt = " ".join( + [ + "The quick brown fox jumps over the lazy dog.", + "A journey of a thousand miles begins with a single step.", + "To be or not to be, that is the question.", + "All that glitters is not gold.", + "The only thing we have to fear is fear itself.", + "In the beginning, there was nothing but darkness and void.", + "Science is organized knowledge; wisdom is organized life.", + "The unexamined life is not worth living.", + "I think, therefore I am.", + "That which does not kill us makes us stronger.", + "The greatest glory in living lies not in never falling,", + "but in rising every time we fall.", + "Life is what happens when you are busy making other plans.", + "The way to get started is to quit talking and begin doing.", + "If you look at what you have in life,", + "you will always have more. In conclusion, the meaning of", + ] + ) + o = llm_instance.generate([long_prompt], SamplingParams(max_tokens=128)) + print(f" Input: ({len(long_prompt.split())} words) {long_prompt[:80]!r}...") + print(f" Output: {o[0].outputs[0].text!r}") + stats = collect_stats(llm_instance, all_collected) + if _is_verbose(request): + print_kv_stats("Long context", stats) + kv = find_kv_entries(stats) + + assert len(kv) > 0, "kvCacheIterationStats not found" + max_used = max(v["primaryUsedNumBlocks"] for _, _, v in kv) + max_alloc = max(v["iterAllocTotalBlocks"] for _, _, v in kv) + assert max_used > 0, "primaryUsedNumBlocks = 0 in all entries" + assert max_alloc > 0, "iterAllocTotalBlocks = 0 in all entries" + assert any(v["iterGenAllocBlocks"] > 0 for _, _, v in kv), ( + "iterGenAllocBlocks = 0 in all entries" + ) + + def test_rapid_fire(self, llm_instance, all_collected, request): + """Rapid-fire — 20 short requests to accumulate deltas.""" + for i in range(20): + prompt = f"Count to {i}: " + o = llm_instance.generate([prompt], SamplingParams(max_tokens=32)) + if i % 5 == 0: + print(f" Input: {prompt!r}") + print(f" Output: 
{o[0].outputs[0].text!r}") + stats = collect_stats(llm_instance, all_collected) + if _is_verbose(request): + print_kv_stats("Rapid-fire", stats) + kv = find_kv_entries(stats) + + assert len(kv) > 0, "kvCacheIterationStats not found" + total_gen = sum(v["iterGenAllocBlocks"] for _, _, v in kv) + total_alloc = sum(v["iterAllocTotalBlocks"] for _, _, v in kv) + assert total_gen > 0, "iterGenAllocBlocks = 0 across all entries" + assert total_alloc > 0, "iterAllocTotalBlocks = 0 across all entries" + + def test_field_completeness(self, llm_instance, all_collected, request): + """Field completeness — verify all 18 fields present across all collected stats.""" + # If running standalone (no prior tests), generate some traffic + if not all_collected: + llm_instance.generate(["Hello world"], SamplingParams(max_tokens=16)) + collect_stats(llm_instance, all_collected) + + entries_with_kv = 0 + missing_fields = set() + for s in all_collected: + ki = s.get("kvCacheIterationStats") + if ki: + entries_with_kv += 1 + for ws, v in ki.items(): + for field in ALL_FIELDS: + if field not in v: + missing_fields.add(field) + + print(f" Entries with kvCacheIterationStats: {entries_with_kv}/{len(all_collected)}") + assert entries_with_kv > 0, "no entries contain kvCacheIterationStats" + assert len(missing_fields) == 0, f"Missing fields: {sorted(missing_fields)}" + + +# --------------------------------------------------------------------------- +# Standalone execution (python3 ... 
directly) +# --------------------------------------------------------------------------- +_STANDALONE_TEST_FUNCS = { + 1: "test_cold_start", + 2: "test_partial_block_reuse", + 3: "test_full_block_reuse", + 4: "test_shared_prefix", + 5: "test_batch_generation", + 6: "test_long_context", + 7: "test_rapid_fire", + 8: "test_field_completeness", +} + + +def main(): + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="Dump all 18 KV cache stat fields for every stats entry", + ) + parser.add_argument( + "--test", + "-t", + type=int, + nargs="+", + metavar="N", + help="Run only the specified test(s) by number (e.g. --test 2 3)", + ) + parser.add_argument("--list", "-l", action="store_true", help="List available tests and exit") + args = parser.parse_args() + + if args.list: + print("Available tests:") + for num, name in TEST_NAMES.items(): + print(f" {num} {name}") + return + + selected = args.test if args.test else sorted(_STANDALONE_TEST_FUNCS.keys()) + invalid = [t for t in selected if t not in _STANDALONE_TEST_FUNCS] + if invalid: + parser.error(f"Unknown test number(s): {invalid}. 
Use --list to see available tests.") + + # Create a fake request object for verbose flag + class FakeConfig: + def getoption(self, name, default=False): + return args.verbose + + class FakeRequest: + config = FakeConfig() + + fake_request = FakeRequest() + + print("Starting LLM with block_reuse + iteration_stats_interval=1") + llm = LLM( + model=MODEL, + kv_cache_config=KvCacheConfig(enable_block_reuse=True, iteration_stats_interval=1), + enable_iter_perf_stats=True, + return_perf_metrics=True, + ) + all_collected = [] + results = {} + test_cls = TestKvCacheIterationStats() + + for t in selected: + name = f"Test {t}: {TEST_NAMES[t]}" + method = getattr(test_cls, _STANDALONE_TEST_FUNCS[t]) + try: + method(llm, all_collected, fake_request) + results[name] = True + print(" PASS") + except AssertionError as e: + results[name] = False + print(f" FAIL: {e}") + + llm.shutdown() + + print(f"\n{'=' * 60}") + print(" SUMMARY") + print(f"{'=' * 60}") + passed = sum(1 for v in results.values() if v) + total = len(results) + for name, ok in results.items(): + print(f" {'PASS' if ok else 'FAIL'}: {name}") + print(f"\n {passed}/{total} tests passed.") + if passed < total: + print(" Some tests FAILED — review output above.") + else: + print(" All tests passed!") + + +if __name__ == "__main__": + main() diff --git a/tests/integration/defs/perf/pytorch_model_config.py b/tests/integration/defs/perf/pytorch_model_config.py index e480355084a..70c22fdb083 100644 --- a/tests/integration/defs/perf/pytorch_model_config.py +++ b/tests/integration/defs/perf/pytorch_model_config.py @@ -93,7 +93,6 @@ def get_model_yaml_config(model_label: str, 'enable_padding': False }, 'moe_config': { - 'backend': 'TRTLLM', 'max_num_tokens': 32768 }, 'speculative_config': { @@ -228,18 +227,6 @@ def get_model_yaml_config(model_label: str, 'enable_attention_dp': True, } }, - # Qwen3 models with fp4 quantization on B200 with moe backend equal to TRTLLM - { - 'patterns': [ - 
'qwen3_235b_a22b_fp4-bench-pytorch-float4-maxbs:512-maxnt:2048-input_output_len:1000,2000-con:8-ep:8-gpus:8', - ], - 'config': { - 'enable_attention_dp': False, - 'moe_config': { - 'backend': 'TRTLLM' - } - } - }, { 'patterns': [ 'qwen3_4b-bench-pytorch-streaming-bfloat16-maxbs:4-kv_frac:0.6-input_output_len:500,100-reqs:200-con:4', @@ -317,9 +304,6 @@ def get_model_yaml_config(model_label: str, 'enable_padding': True, 'max_batch_size': 720, }, - 'moe_config': { - 'backend': 'TRTLLM' - }, 'stream_interval': 10, 'num_postprocess_workers': 4 } @@ -336,9 +320,6 @@ def get_model_yaml_config(model_label: str, 'enable_padding': True, 'max_batch_size': 720, }, - 'moe_config': { - 'backend': 'TRTLLM' - }, 'stream_interval': 10, 'num_postprocess_workers': 4 } diff --git a/tests/integration/test_lists/dev/.gitignore b/tests/integration/test_lists/dev/.gitignore deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index 0a2c7a9907a..29e4806f395 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -156,6 +156,17 @@ l0_b200: - kv_cache/test_kv_cache_v2_scheduler.py::TestKVCacheV2LoRA::test_lora_multi_adapter_v2 - kv_cache/test_kv_cache_v2_scheduler.py::TestKVCacheV2LoRA::test_lora_chunked_prefill - kv_cache/test_kv_cache_v2_scheduler.py::TestKVCacheV2LoRA::test_lora_eviction + # ------------- KV Cache Iteration Stats --------------- + - unittest/executor/test_stats_serializer.py + - unittest/metrics/test_collector.py + - kv_cache/test_kv_cache_iteration_stats.py::TestKvCacheIterationStats::test_cold_start + - kv_cache/test_kv_cache_iteration_stats.py::TestKvCacheIterationStats::test_partial_block_reuse + - kv_cache/test_kv_cache_iteration_stats.py::TestKvCacheIterationStats::test_full_block_reuse + - kv_cache/test_kv_cache_iteration_stats.py::TestKvCacheIterationStats::test_shared_prefix + - 
kv_cache/test_kv_cache_iteration_stats.py::TestKvCacheIterationStats::test_batch_generation + - kv_cache/test_kv_cache_iteration_stats.py::TestKvCacheIterationStats::test_long_context + - kv_cache/test_kv_cache_iteration_stats.py::TestKvCacheIterationStats::test_rapid_fire + - kv_cache/test_kv_cache_iteration_stats.py::TestKvCacheIterationStats::test_field_completeness # ------------- Visual Gen tests --------------- - unittest/_torch/visual_gen/test_visual_gen_args.py - unittest/_torch/visual_gen/test_warmup.py diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index ed1039b81e9..00394f3eec1 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -440,6 +440,17 @@ l0_h100: - unittest/trt/attention/test_gpt_attention_no_cache.py - examples/test_gpt.py::test_gpt_oss_20b_lora_torch[gpt-oss-20b-lora-adapter_NIM_r8-gpt-oss-20b] - unittest/kv_cache_manager_v2_tests/ # 4 min + # ------------- KV Cache Iteration Stats --------------- + - unittest/executor/test_stats_serializer.py + - unittest/metrics/test_collector.py + - kv_cache/test_kv_cache_iteration_stats.py::TestKvCacheIterationStats::test_cold_start + - kv_cache/test_kv_cache_iteration_stats.py::TestKvCacheIterationStats::test_partial_block_reuse + - kv_cache/test_kv_cache_iteration_stats.py::TestKvCacheIterationStats::test_full_block_reuse + - kv_cache/test_kv_cache_iteration_stats.py::TestKvCacheIterationStats::test_shared_prefix + - kv_cache/test_kv_cache_iteration_stats.py::TestKvCacheIterationStats::test_batch_generation + - kv_cache/test_kv_cache_iteration_stats.py::TestKvCacheIterationStats::test_long_context + - kv_cache/test_kv_cache_iteration_stats.py::TestKvCacheIterationStats::test_rapid_fire + - kv_cache/test_kv_cache_iteration_stats.py::TestKvCacheIterationStats::test_field_completeness - condition: ranges: system_gpu_count: diff --git a/tests/integration/test_lists/waives.txt 
b/tests/integration/test_lists/waives.txt index 8d7457b3460..7bcc6643dcd 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -334,6 +334,7 @@ examples/test_visual_gen.py::test_vbench_dimension_score_wan22_a14b_fp8 SKIP (ht visual_gen/test_visual_gen_benchmark.py::test_offline_benchmark SKIP (https://nvbugs/6050483) accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False] SKIP (https://nvbugs/6050487) accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/6050489) +disaggregated/test_disaggregated.py::test_disaggregated_gpt_oss_120b_harmony[gpt_oss/gpt-oss-120b] SKIP (https://nvbugs/6011317) accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/6050489) disaggregated/test_disaggregated.py::test_disaggregated_overlap_gen_first[ctx_pp1-TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/6057459) disaggregated/test_disaggregated.py::test_disaggregated_overlap_gen_first[ctx_pp4-TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/6057460) diff --git a/tests/unittest/executor/test_stats_serializer.py b/tests/unittest/executor/test_stats_serializer.py new file mode 100644 index 00000000000..0c00ca9322f --- /dev/null +++ b/tests/unittest/executor/test_stats_serializer.py @@ -0,0 +1,198 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for _stats_serializer with kvCacheIterationStats injection.""" + +import json +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from tensorrt_llm.executor.base_worker import BaseWorker + + +def _make_mock_iteration_stats(kv_cache_stats_json=None): + """Create a mock IterationStats object with to_json_str().""" + base = { + "iter": 1, + "iterLatencyMS": 10.5, + "gpuMemUsage": 1024, + "cpuMemUsage": 0, + "pinnedMemUsage": 0, + } + if kv_cache_stats_json is not None: + base["kvCacheStats"] = kv_cache_stats_json + + mock = MagicMock() + mock.to_json_str.return_value = json.dumps(base) + return mock + + +def _make_mock_kv_iter_stats( + window_size=16, + primary_used=10, + primary_max=20, + reused=5, + full_reused=4, + partial_reused=1, + missed=3, + gen_alloc=2, +): + """Create a mock KvCacheIterationStats nanobind object.""" + s = SimpleNamespace( + primary_max_num_blocks=primary_max, + primary_free_num_blocks=primary_max - primary_used, + primary_used_num_blocks=primary_used, + secondary_max_num_blocks=0, + secondary_free_num_blocks=0, + secondary_used_num_blocks=0, + iter_alloc_total_blocks=reused + missed, + iter_alloc_new_blocks=missed, + iter_reused_blocks=reused, + iter_full_reused_blocks=full_reused, + iter_partial_reused_blocks=partial_reused, + iter_missed_blocks=missed, + iter_cache_hit_rate=reused / (reused + missed) if (reused + missed) > 0 else 0.0, + iter_gen_alloc_blocks=gen_alloc, + iter_onboard_blocks=1, + iter_onboard_bytes=4096, + iter_offload_blocks=0, + iter_offload_bytes=0, + 
iter_intra_device_copy_blocks=2, + iter_intra_device_copy_bytes=8192, + ) + return {window_size: s} + + +class TestStatsSerializer: + def test_serializer_without_kv_iter_stats(self): + """Legacy 2-tuple and 3-tuple with None should produce same output.""" + iter_stats = _make_mock_iteration_stats() + + # 3-tuple with None kv_iter_stats + result = BaseWorker._stats_serializer((iter_stats, None, None)) + d = json.loads(result) + assert "iter" in d + assert "kvCacheIterationStats" not in d + + def test_serializer_with_kv_iter_stats(self): + """KvCacheIterationStats should appear when provided.""" + iter_stats = _make_mock_iteration_stats( + kv_cache_stats_json={"maxNumBlocks": 20, "usedNumBlocks": 10} + ) + kv_iter = _make_mock_kv_iter_stats( + window_size=16, + primary_used=10, + primary_max=20, + reused=5, + full_reused=4, + partial_reused=1, + missed=3, + gen_alloc=2, + ) + + result = BaseWorker._stats_serializer((iter_stats, None, kv_iter)) + d = json.loads(result) + + # Existing kvCacheStats should still be present + assert "kvCacheStats" in d + + # New kvCacheIterationStats should be present + assert "kvCacheIterationStats" in d + iter_kv = d["kvCacheIterationStats"] + assert "16" in iter_kv # window size key as string + + ws_stats = iter_kv["16"] + assert ws_stats["primaryMaxNumBlocks"] == 20 + assert ws_stats["primaryUsedNumBlocks"] == 10 + assert ws_stats["primaryFreeNumBlocks"] == 10 + assert ws_stats["iterReusedBlocks"] == 5 + assert ws_stats["iterFullReusedBlocks"] == 4 + assert ws_stats["iterPartialReusedBlocks"] == 1 + assert ws_stats["iterMissedBlocks"] == 3 + assert ws_stats["iterGenAllocBlocks"] == 2 + assert ws_stats["iterOnboardBlocks"] == 1 + assert ws_stats["iterOnboardBytes"] == 4096 + assert ws_stats["iterOffloadBlocks"] == 0 + assert ws_stats["iterOffloadBytes"] == 0 + assert ws_stats["iterIntraDeviceCopyBlocks"] == 2 + assert ws_stats["iterIntraDeviceCopyBytes"] == 8192 + assert ws_stats["iterCacheHitRate"] == pytest.approx(5 / 8) + + def 
test_serializer_multiple_window_sizes(self): + """Multiple window sizes should all appear in output.""" + iter_stats = _make_mock_iteration_stats() + kv_iter = _make_mock_kv_iter_stats( + window_size=16, + primary_used=5, + primary_max=10, + reused=2, + full_reused=2, + partial_reused=0, + missed=1, + gen_alloc=0, + ) + # Add a second window size + kv_iter[64] = _make_mock_kv_iter_stats( + window_size=64, + primary_used=8, + primary_max=16, + reused=3, + full_reused=1, + partial_reused=2, + missed=2, + gen_alloc=1, + )[64] + + result = BaseWorker._stats_serializer((iter_stats, None, kv_iter)) + d = json.loads(result) + + iter_kv = d["kvCacheIterationStats"] + assert "16" in iter_kv + assert "64" in iter_kv + assert iter_kv["16"]["primaryMaxNumBlocks"] == 10 + assert iter_kv["64"]["primaryMaxNumBlocks"] == 16 + + def test_serializer_with_request_stats(self): + """Request stats and kv iter stats should coexist.""" + iter_stats = _make_mock_iteration_stats() + kv_iter = _make_mock_kv_iter_stats() + + req_stat = MagicMock() + req_stat.to_json_str.return_value = json.dumps({"id": 42}) + + result = BaseWorker._stats_serializer((iter_stats, [req_stat], kv_iter)) + d = json.loads(result) + + assert "requestStats" in d + assert len(d["requestStats"]) == 1 + assert d["requestStats"][0]["id"] == 42 + assert "kvCacheIterationStats" in d + + def test_serializer_none_on_off_interval(self): + """When kv_iter_stats is None (off-interval), field should be absent.""" + iter_stats = _make_mock_iteration_stats() + + result = BaseWorker._stats_serializer((iter_stats, None, None)) + d = json.loads(result) + assert "kvCacheIterationStats" not in d + + def test_serializer_legacy_2_tuple(self): + """Legacy 2-tuple without third element should work.""" + iter_stats = _make_mock_iteration_stats() + + result = BaseWorker._stats_serializer((iter_stats, None)) + d = json.loads(result) + assert "kvCacheIterationStats" not in d diff --git a/tests/unittest/llmapi/apps/_test_openai_metrics.py 
b/tests/unittest/llmapi/apps/_test_openai_metrics.py index 9bccfe2312f..dc40c31a1e2 100644 --- a/tests/unittest/llmapi/apps/_test_openai_metrics.py +++ b/tests/unittest/llmapi/apps/_test_openai_metrics.py @@ -98,3 +98,19 @@ def test_metrics(client): assert "pinnedMemUsage" in response_dict assert "staticBatchingStats" in response_dict assert "timestamp" in response_dict + # Per-iteration KV cache stats (keyed by window size) + assert "kvCacheIterationStats" in response_dict + kv_iter = response_dict["kvCacheIterationStats"] + assert len(kv_iter) > 0 + # Check fields in the first (and likely only) window size entry + ws_stats = next(iter(kv_iter.values())) + assert "primaryMaxNumBlocks" in ws_stats + assert "primaryUsedNumBlocks" in ws_stats + assert "iterReusedBlocks" in ws_stats + assert "iterFullReusedBlocks" in ws_stats + assert "iterPartialReusedBlocks" in ws_stats + assert "iterMissedBlocks" in ws_stats + assert "iterCacheHitRate" in ws_stats + assert "iterGenAllocBlocks" in ws_stats + assert "iterOnboardBlocks" in ws_stats + assert "iterOnboardBytes" in ws_stats diff --git a/tests/unittest/llmapi/apps/_test_openai_prometheus.py b/tests/unittest/llmapi/apps/_test_openai_prometheus.py index 4cd65051fc3..10b33483c10 100644 --- a/tests/unittest/llmapi/apps/_test_openai_prometheus.py +++ b/tests/unittest/llmapi/apps/_test_openai_prometheus.py @@ -45,7 +45,11 @@ def temp_extra_llm_api_options_file(request): try: extra_llm_api_options_dict = { "return_perf_metrics": True, - "enable_iter_perf_stats": True + "enable_iter_perf_stats": True, + "kv_cache_config": { + "enable_block_reuse": True, + "iteration_stats_interval": 1, + }, } with open(temp_file_path, 'w') as f: @@ -113,8 +117,9 @@ def _parse_all_kv_metrics(data: str, prefix: str) -> Dict[str, float | None]: """ names = [ prefix + "kv_cache_hit_rate", - prefix + "kv_cache_reused_blocks_total", - prefix + "kv_cache_missed_blocks_total", + prefix + "kv_cache_iter_reused_blocks_total", + prefix + 
"kv_cache_iter_missed_blocks_total", + prefix + "kv_cache_iter_reuse_rate", prefix + "kv_cache_utilization", ] return {name: _parse_prometheus_sample(data, name) for name in names} @@ -188,18 +193,22 @@ def test_metrics_endpoint(server: RemoteOpenAIServer): for name, value in kv_metrics.items(): assert value is not None, f"No sample value found for {name}" - # Verify post-warmup values match expected behavior: - # Two identical requests → 1 reused block, 1 missed block, 0.5 hit rate + # Verify post-warmup values are populated and non-zero. + # Exact values are non-deterministic because counters accumulate across + # all scheduler iterations between request completion and metric scrape. hit_rate = kv_metrics[METRIC_PREFIX + "kv_cache_hit_rate"] - reused = kv_metrics[METRIC_PREFIX + "kv_cache_reused_blocks_total"] - missed = kv_metrics[METRIC_PREFIX + "kv_cache_missed_blocks_total"] + reused = kv_metrics[METRIC_PREFIX + "kv_cache_iter_reused_blocks_total"] + missed = kv_metrics[METRIC_PREFIX + "kv_cache_iter_missed_blocks_total"] utilization = kv_metrics[METRIC_PREFIX + "kv_cache_utilization"] - assert hit_rate == pytest.approx(0.5), \ - f"Expected kv_cache_hit_rate == 0.5, got {hit_rate}" - assert reused == 1.0, \ - f"Expected kv_cache_reused_blocks_total == 1.0, got {reused}" - assert missed == 1.0, \ - f"Expected kv_cache_missed_blocks_total == 1.0, got {missed}" + assert hit_rate > 0, \ + f"Expected kv_cache_hit_rate > 0, got {hit_rate}" + assert reused > 0, \ + f"Expected kv_cache_iter_reused_blocks_total > 0, got {reused}" + assert missed > 0, \ + f"Expected kv_cache_iter_missed_blocks_total > 0, got {missed}" assert utilization >= 0, \ f"Expected kv_cache_utilization >= 0, got {utilization}" + + assert METRIC_PREFIX + "kv_cache_hit_rate" in data + assert METRIC_PREFIX + "kv_cache_iter_reuse_rate" in data diff --git a/tests/unittest/metrics/test_collector.py b/tests/unittest/metrics/test_collector.py new file mode 100644 index 00000000000..c9f80fd2551 --- 
/dev/null +++ b/tests/unittest/metrics/test_collector.py @@ -0,0 +1,249 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for MetricsCollector.log_iteration_stats with kvCacheIterationStats.""" + +import pytest + +prometheus_client = pytest.importorskip("prometheus_client") + +from tensorrt_llm.metrics.collector import MetricsCollector # noqa: E402 + +# Use a single module-level collector to avoid re-registering Prometheus metrics +# (Prometheus does not allow duplicate metric names in the same process). 
+_collector = MetricsCollector(labels={"test": "true"}) + + +def _make_collector() -> MetricsCollector: + """Return the shared collector instance.""" + return _collector + + +def _get_gauge_value(collector, metric_name: str): + """Get the current value of a Prometheus gauge.""" + metric = getattr(collector, metric_name) + return metric.labels(**collector.labels)._value.get() + + +def _get_counter_value(collector, metric_name: str): + """Get the current value of a Prometheus counter.""" + metric = getattr(collector, metric_name) + return metric.labels(**collector.labels)._value.get() + + +class TestLogIterationStatsKvCacheIteration: + def test_no_kv_cache_iteration_stats(self): + """When kvCacheIterationStats is absent, new metrics should not error.""" + collector = _make_collector() + stats = {"kvCacheStats": {"cacheHitRate": 0.5, "usedNumBlocks": 10, "maxNumBlocks": 20}} + # Should not raise + collector.log_iteration_stats(stats) + + def test_gauges_updated(self): + """Host utilization and iter reuse rate gauges should be set.""" + collector = _make_collector() + stats = { + "kvCacheIterationStats": { + "16": { + "primaryMaxNumBlocks": 100, + "primaryFreeNumBlocks": 60, + "primaryUsedNumBlocks": 40, + "secondaryMaxNumBlocks": 50, + "secondaryFreeNumBlocks": 30, + "secondaryUsedNumBlocks": 20, + "iterAllocTotalBlocks": 8, + "iterAllocNewBlocks": 3, + "iterReusedBlocks": 5, + "iterFullReusedBlocks": 4, + "iterPartialReusedBlocks": 1, + "iterMissedBlocks": 3, + "iterCacheHitRate": 0.625, + "iterGenAllocBlocks": 2, + "iterOnboardBlocks": 1, + "iterOnboardBytes": 4096, + "iterOffloadBlocks": 0, + "iterOffloadBytes": 0, + } + } + } + collector.log_iteration_stats(stats) + + # Host utilization = 20/50 = 0.4 + assert _get_gauge_value(collector, "kv_cache_host_utilization") == pytest.approx(0.4) + # Iter reuse rate = 5/(5+3) = 0.625 + assert _get_gauge_value(collector, "kv_cache_iter_reuse_rate") == pytest.approx(0.625) + + def test_counters_incremented(self): + """Counter 
metrics should accumulate deltas across calls.""" + collector = _make_collector() + stats = { + "kvCacheIterationStats": { + "16": { + "primaryMaxNumBlocks": 100, + "primaryFreeNumBlocks": 60, + "primaryUsedNumBlocks": 40, + "secondaryMaxNumBlocks": 0, + "secondaryFreeNumBlocks": 0, + "secondaryUsedNumBlocks": 0, + "iterAllocTotalBlocks": 8, + "iterAllocNewBlocks": 3, + "iterReusedBlocks": 5, + "iterFullReusedBlocks": 4, + "iterPartialReusedBlocks": 1, + "iterMissedBlocks": 3, + "iterCacheHitRate": 0.625, + "iterGenAllocBlocks": 2, + "iterOnboardBlocks": 1, + "iterOnboardBytes": 4096, + "iterOffloadBlocks": 1, + "iterOffloadBytes": 2048, + "iterIntraDeviceCopyBlocks": 2, + "iterIntraDeviceCopyBytes": 8192, + } + } + } + + # Read baseline counter values (may be non-zero from prior tests) + before_reused = _get_counter_value(collector, "kv_cache_iter_reused_blocks") + before_missed = _get_counter_value(collector, "kv_cache_iter_missed_blocks") + before_onboard = _get_counter_value(collector, "kv_cache_onboard_bytes_total") + before_intra_device = _get_counter_value( + collector, "kv_cache_intra_device_copy_bytes_total" + ) + + # First call + collector.log_iteration_stats(stats) + assert _get_counter_value( + collector, "kv_cache_iter_reused_blocks" + ) - before_reused == pytest.approx(5) + assert _get_counter_value( + collector, "kv_cache_iter_missed_blocks" + ) - before_missed == pytest.approx(3) + assert _get_counter_value( + collector, "kv_cache_onboard_bytes_total" + ) - before_onboard == pytest.approx(4096) + assert _get_counter_value( + collector, "kv_cache_intra_device_copy_bytes_total" + ) - before_intra_device == pytest.approx(8192) + + # Second call — counters should accumulate further + collector.log_iteration_stats(stats) + assert _get_counter_value( + collector, "kv_cache_iter_reused_blocks" + ) - before_reused == pytest.approx(10) + assert _get_counter_value( + collector, "kv_cache_iter_missed_blocks" + ) - before_missed == pytest.approx(6) + assert 
_get_counter_value( + collector, "kv_cache_onboard_bytes_total" + ) - before_onboard == pytest.approx(8192) + assert _get_counter_value( + collector, "kv_cache_intra_device_copy_bytes_total" + ) - before_intra_device == pytest.approx(16384) + + def test_multiple_windows_aggregated(self): + """Stats from multiple window sizes should be summed.""" + collector = _make_collector() + ws16 = { + "primaryMaxNumBlocks": 50, + "primaryFreeNumBlocks": 30, + "primaryUsedNumBlocks": 20, + "secondaryMaxNumBlocks": 10, + "secondaryFreeNumBlocks": 5, + "secondaryUsedNumBlocks": 5, + "iterAllocTotalBlocks": 4, + "iterAllocNewBlocks": 2, + "iterReusedBlocks": 2, + "iterFullReusedBlocks": 2, + "iterPartialReusedBlocks": 0, + "iterMissedBlocks": 2, + "iterCacheHitRate": 0.5, + "iterGenAllocBlocks": 1, + "iterOnboardBlocks": 0, + "iterOnboardBytes": 0, + "iterOffloadBlocks": 0, + "iterOffloadBytes": 0, + } + ws64 = { + "primaryMaxNumBlocks": 50, + "primaryFreeNumBlocks": 40, + "primaryUsedNumBlocks": 10, + "secondaryMaxNumBlocks": 10, + "secondaryFreeNumBlocks": 2, + "secondaryUsedNumBlocks": 8, + "iterAllocTotalBlocks": 5, + "iterAllocNewBlocks": 2, + "iterReusedBlocks": 3, + "iterFullReusedBlocks": 1, + "iterPartialReusedBlocks": 2, + "iterMissedBlocks": 2, + "iterCacheHitRate": 0.6, + "iterGenAllocBlocks": 0, + "iterOnboardBlocks": 1, + "iterOnboardBytes": 8192, + "iterOffloadBlocks": 0, + "iterOffloadBytes": 0, + } + stats = {"kvCacheIterationStats": {"16": ws16, "64": ws64}} + + # Read baseline + before_reused = _get_counter_value(collector, "kv_cache_iter_reused_blocks") + + collector.log_iteration_stats(stats) + + # Host utilization = (5+8) / (10+10) = 13/20 = 0.65 + assert _get_gauge_value(collector, "kv_cache_host_utilization") == pytest.approx(0.65) + # Iter reuse rate = (2+3) / (2+3+2+2) = 5/9 + assert _get_gauge_value(collector, "kv_cache_iter_reuse_rate") == pytest.approx(5 / 9) + # Counters: delta should be 5 (2+3) + assert _get_counter_value( + collector, 
"kv_cache_iter_reused_blocks" + ) - before_reused == pytest.approx(5) + + def test_zero_deltas_no_counter_increment(self): + """When all deltas are zero, counters should not increment.""" + collector = _make_collector() + stats = { + "kvCacheIterationStats": { + "16": { + "primaryMaxNumBlocks": 100, + "primaryFreeNumBlocks": 100, + "primaryUsedNumBlocks": 0, + "secondaryMaxNumBlocks": 0, + "secondaryFreeNumBlocks": 0, + "secondaryUsedNumBlocks": 0, + "iterAllocTotalBlocks": 0, + "iterAllocNewBlocks": 0, + "iterReusedBlocks": 0, + "iterFullReusedBlocks": 0, + "iterPartialReusedBlocks": 0, + "iterMissedBlocks": 0, + "iterCacheHitRate": 0.0, + "iterGenAllocBlocks": 0, + "iterOnboardBlocks": 0, + "iterOnboardBytes": 0, + "iterOffloadBlocks": 0, + "iterOffloadBytes": 0, + } + } + } + before_reused = _get_counter_value(collector, "kv_cache_iter_reused_blocks") + before_missed = _get_counter_value(collector, "kv_cache_iter_missed_blocks") + collector.log_iteration_stats(stats) + assert _get_counter_value(collector, "kv_cache_iter_reused_blocks") == pytest.approx( + before_reused + ) + assert _get_counter_value(collector, "kv_cache_iter_missed_blocks") == pytest.approx( + before_missed + )