Merged
82 changes: 82 additions & 0 deletions .claude/README.md
@@ -0,0 +1,82 @@
# Custom Claude Code Skills & Agents for TensorRT-LLM

## Background: Skills & agents in Claude Code

Claude Code supports two extensibility mechanisms — **skills** and **agents** —
that let teams encode domain expertise into reusable, version-controlled
components.

**Skills** are markdown playbooks that Claude follows step-by-step when
triggered. They are invoked via `/slash-commands` (e.g. `/perf-analysis`) or
matched automatically from natural-language requests. Each skill lives in its
own directory under `.claude/skills/` and can bundle reference materials that
Claude reads during execution. See
[Custom slash commands](https://code.claude.com/docs/en/skills)
for details.

**Agents** (sub-agents) are specialist workers that Claude spawns in a separate
context to handle focused tasks. Each agent has its own system prompt, tool
access, and domain knowledge. Claude delegates to them when it determines a task
fits a specialist's scope, while you can also invoke agents directly. Agent
definitions live under `.claude/agents/`. See
[Custom sub-agents](https://code.claude.com/docs/en/sub-agents)
for details.

## How skills and agents are loaded

When running Claude Code from the TensorRT-LLM project directory, skills and
agents are discovered automatically at startup; no manual registration is
needed. Files placed in `.claude/skills/` and `.claude/agents/` are picked up
by convention.

To verify what's loaded, launch Claude Code from the TensorRT-LLM project
directory and type `/skills` or `/agents` at the prompt to list the available
skills and sub-agents.

## How to use skills and agents

There are two ways to trigger skills and agents:

1. **Automatic dispatch** — just describe what you need in plain language
(e.g. "profile this workload", "compile TensorRT-LLM"). Claude Code will
match your request to the appropriate skill or delegate to the right
sub-agent automatically.

2. **Manual invocation** — type `/<skill-name>` (e.g. `/perf-analysis`,
   `/serve-config-guide`) to explicitly run a skill. For sub-agents, type
   `@"<agent-name>" (agent)` (e.g. `@"exec-compile-specialist (agent)"`) to
   delegate a task directly. This is useful when you know exactly which
   workflow you want.

In most cases, automatic dispatch is sufficient — you don't need to memorize
skill or agent names. Manual invocation is there for when you want precise control.

References:
* [Extend Claude with skills](https://code.claude.com/docs/en/skills)
* [Work with subagents](https://code.claude.com/docs/en/sub-agents#work-with-subagents)

## Naming convention

Every skill and agent name uses the format `<prefix>-<descriptive-name>`.
The prefix identifies the primary work area; the descriptive part should be
short and not repeat it.

| Prefix | Domain | Definition |
|---|---|---|
| `ad-` | AutoDeploy | Model onboarding, pipeline debugging, and execution for the AutoDeploy backend |
| `ci-` | CI/CD | CI failure retrieval, test diagnostics, and pipeline workflows |
| `exec-` | Execution infra | Environment setup and job execution (compile, run, container) |
| `kernel-` | Kernel development | Kernel writing, generation, and kernel-specific transforms |
| `perf-` | Performance work | Profiling, analysis, and tuning above the kernel layer (kernel modifications belong under `kernel-`) |
| `serve-` | Serving | Serving configuration, deployment, and runtime workflows |
| `trtllm-` | TRT-LLM dev workflows | Workflows for reading, modifying, and contributing to the codebase (static subsystem knowledge belongs in repo docs) |

Guidelines:

* If a skill doesn't fit any prefix, propose a new one and agree on its
boundary before using it.
* Use the prefix of the skill's **primary** domain, even if it orchestrates
across multiple domains.
* Agents follow the same convention.
* Good: `exec-local-compile`, `kernel-cuda-writing`, `perf-host-analysis`
* Bad: `exec-trtllm-compile`, `kernel-cuda-kernel-writing`,
`perf-trtllm-host-analysis`
87 changes: 83 additions & 4 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -259,6 +259,42 @@ struct KvCacheStats
std::size_t allocatedBytes{};
};

/// @brief Per-iteration KV cache statistics. All delta counters represent changes since the last call to
/// getIterationStats(). Gauges are instantaneous snapshots.
struct KvCacheIterationStats
{
// --- Instantaneous gauges ---
// Primary (GPU) pool
SizeType32 primaryMaxNumBlocks{0};
SizeType32 primaryFreeNumBlocks{0};
SizeType32 primaryUsedNumBlocks{0};
// Secondary (host) pool
SizeType32 secondaryMaxNumBlocks{0};
SizeType32 secondaryFreeNumBlocks{0};
SizeType32 secondaryUsedNumBlocks{0};

// --- Per-iteration deltas (reset on each read) ---
// Context phase: block allocation and reuse
SizeType32 iterAllocTotalBlocks{0};
SizeType32 iterAllocNewBlocks{0};
SizeType32 iterReusedBlocks{0}; // = iterFullReusedBlocks + iterPartialReusedBlocks
SizeType32 iterFullReusedBlocks{0}; // blocks fully matched in radix tree
SizeType32 iterPartialReusedBlocks{0}; // blocks partially matched in radix tree
SizeType32 iterMissedBlocks{0};
float iterCacheHitRate{0.0f};
// Generation phase: block allocation
SizeType32 iterGenAllocBlocks{0};

// Transfer traffic deltas — host ↔ GPU
SizeType32 iterOnboardBlocks{0};
std::size_t iterOnboardBytes{0};
SizeType32 iterOffloadBlocks{0};
std::size_t iterOffloadBytes{0};
// Intra-device (GPU → GPU) block copies (e.g. partial reuse when source block has refs)
SizeType32 iterIntraDeviceCopyBlocks{0};
std::size_t iterIntraDeviceCopyBytes{0};
};

// Basic building block of a paged KV cache - a single
// cache block. This class just holds metadata, no pointers
// since it is reused across all layers.
@@ -815,6 +851,12 @@ class WindowBlockManager
return mMissedBlocks;
}

// Get num free blocks in the secondary (host) memory pool
[[nodiscard]] SizeType32 getNumFreeSecondaryBlocks() const noexcept;

//! \brief Get iteration stats (deltas since last call) for this window. Resets internal delta snapshots.
[[nodiscard]] KvCacheIterationStats getAndResetIterationStats();

[[nodiscard]] bool hasFreeBlocks(SizeType32 numRequired = 1) const
{
return getNumFreeBlocks() >= numRequired;
@@ -1128,16 +1170,22 @@ class WindowBlockManager
std::shared_ptr<KVCacheTransferManager> mTransferManager;

// Statistics for block allocations/reuse
// Total number of blocks allocated by all requests
// Total number of blocks allocated by all requests (context phase)
SizeType32 mAllocTotalBlocks;
// Number of new blocks that were allocated
// Number of new blocks that were allocated (context phase)
SizeType32 mAllocNewBlocks;
// Number of blocks that were reused
// Number of blocks that were fully reused (context phase)
SizeType32 mFullReusedBlocks;
// Number of blocks that were partially reused (context phase)
SizeType32 mPartialReusedBlocks;
// Number of blocks that were reused (full + partial, context phase)
SizeType32 mReusedBlocks;
// Number of unique blocks that were reused
SizeType32 mReusedUniqueBlocks;
// Number of blocks that were not reused
// Number of blocks that were not reused (context phase)
SizeType32 mMissedBlocks;
// Number of blocks allocated during generation phase
SizeType32 mGenAllocBlocks;
// Can only be 1 or 2. If 2, general KV is stored. If 1, K == V for every token, so only K is stored to
// optimize max_num_tokens (e.g. for DeepSeek). Controlled by mCacheType.
SizeType32 mKVFactor;
Expand All @@ -1154,6 +1202,15 @@ class WindowBlockManager
// The kv cache connector manager
std::shared_ptr<kv_connector::KvCacheConnectorManager> mKvCacheConnectorManager;

// Snapshot of cumulative counters at last iteration stats read (for delta computation)
SizeType32 mPrevAllocTotalBlocks{0};
SizeType32 mPrevAllocNewBlocks{0};
SizeType32 mPrevReusedBlocks{0};
SizeType32 mPrevFullReusedBlocks{0};
SizeType32 mPrevPartialReusedBlocks{0};
SizeType32 mPrevMissedBlocks{0};
SizeType32 mPrevGenAllocBlocks{0};

// Mutex for the cached blocks root
mutable std::mutex mCachedBlocksRootMutex;

@@ -1359,6 +1416,19 @@ class BlockManager
return sumWindows([](auto const& manager) { return manager.getNumMissedBlocks(); });
}

[[nodiscard]] SizeType32 getNumSecondaryBlocks() const
{
return sumWindows([](auto const& manager) { return manager.getNumSecondaryBlocks(); });
}

[[nodiscard]] SizeType32 getNumFreeSecondaryBlocks() const
{
return sumWindows([](auto const& manager) { return manager.getNumFreeSecondaryBlocks(); });
}

//! \brief Get per-window-size iteration stats. Resets delta snapshots for each window.
[[nodiscard]] std::map<SizeType32, KvCacheIterationStats> getAndResetIterationStats();

[[nodiscard]] SizeType32 getNumLayers() const
{
return mNumLayers;
@@ -1688,6 +1758,10 @@ class BaseKVCacheManager

[[nodiscard]] virtual KvCacheStats getKvCacheStats() const = 0;

//! \brief Get per-iteration stats with delta counters, keyed by window size.
//! Resets delta snapshots on each call.
[[nodiscard]] virtual std::map<SizeType32, KvCacheIterationStats> getIterationStats() = 0;

[[nodiscard]] virtual OffsetTableDimensions getOffsetTableDimensions() const = 0;

[[nodiscard]] virtual std::deque<executor::KVCacheEvent> getLatestEvents(
@@ -2046,6 +2120,11 @@ class KVCacheManager : public BaseKVCacheManager
return kvCacheStats;
}

[[nodiscard]] std::map<SizeType32, KvCacheIterationStats> getIterationStats() override
{
return mBlockManager.getAndResetIterationStats();
}

[[nodiscard]] OffsetTableDimensions getOffsetTableDimensions() const override
{
OffsetTableDimensions dims;
33 changes: 33 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h
@@ -27,6 +27,20 @@ namespace kvc = tensorrt_llm::executor::kv_cache;
namespace tensorrt_llm::batch_manager::kv_cache_manager
{

/// @brief Statistics for block transfers. Returned by KVCacheTransferManager::getAndResetTransferStats().
/// All counters are reset on read.
/// - onboard/offload: transfers between secondary (host) and primary (GPU) memory.
/// - intraDeviceCopy: GPU-to-GPU block copies (e.g. partial reuse when source block has refs).
struct KvCacheTransferStats
{
SizeType32 onboardBlocks{0};
std::size_t onboardBytes{0};
SizeType32 offloadBlocks{0};
std::size_t offloadBytes{0};
SizeType32 intraDeviceCopyBlocks{0};
std::size_t intraDeviceCopyBytes{0};
};

// The TransferManager accelerates transfers to/from the GPU by overlapping HtoD and DtoH transfers, and tracks ongoing
// transfers in order to avoid race conditions. It is functionally equivalent to the prior approach of putting all
// transfers into the forward pass stream. This is only ever used as a component of a KVCacheManager.
@@ -57,6 +71,9 @@ class KVCacheTransferManager
//! must be called after last call to KVCacheManager::addSequence in every step.
void syncTransfers();

//! \brief Get transfer stats accumulated since last call, and reset the counters.
[[nodiscard]] KvCacheTransferStats getAndResetTransferStats();

private:
//! \brief Get pointer to pool specified by cache block.
static tr::ITensor::SharedPtr computeBlockPointer(
@@ -79,6 +96,12 @@
int numTokensToCopy = 0, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM,
std::string const& directory = "");

//! \brief Compute total bytes actually transferred for a block copy across all pools.
//! \param pools The pool descriptors.
//! \param numTokensToCopy Number of tokens for partial copy (0 means full block).
[[nodiscard]] std::size_t computeBlockTransferBytes(
std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy) const;

runtime::BufferManager mBufferManager;
runtime::BufferManager mOnboardManager;
runtime::BufferManager mOffloadManager;
@@ -90,6 +113,16 @@
// Reference to parent loopback agent
std::shared_ptr<kvc::BaseLoopbackAgent> mLoopbackAgent;
int mDeviceId;

// Cumulative transfer statistics, reset on each call to getAndResetTransferStats().
// Protected by mStatsMutex for thread-safe access.
mutable std::mutex mStatsMutex;
SizeType32 mOnboardBlockCount{0};
std::size_t mOnboardByteCount{0};
SizeType32 mOffloadBlockCount{0};
std::size_t mOffloadByteCount{0};
SizeType32 mIntraDeviceCopyBlockCount{0};
std::size_t mIntraDeviceCopyByteCount{0};
};

} // namespace tensorrt_llm::batch_manager::kv_cache_manager
79 changes: 79 additions & 0 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@@ -760,9 +760,12 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
, mTransferManager{std::make_shared<KVCacheTransferManager>(mBufferManager, mLoopbackAgent)}
, mAllocTotalBlocks{0}
, mAllocNewBlocks{0}
, mFullReusedBlocks{0}
, mPartialReusedBlocks{0}
, mReusedBlocks{0}
, mReusedUniqueBlocks{0}
, mMissedBlocks{0}
, mGenAllocBlocks{0}
, mKVFactor{(mCacheType == CacheType::kSELFKONLY
|| (linearAttentionMetadata.has_value() && linearAttentionMetadata->hasRecurrentStatesCache()))
? 1
@@ -1518,6 +1521,14 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
reusedBlockIds.insert(matchingBlockId);
++mReusedUniqueBlocks;
}
if (partialMatch)
{
++mPartialReusedBlocks;
}
else
{
++mFullReusedBlocks;
}
}
++blockItr;
}
@@ -1726,6 +1737,7 @@ void WindowBlockManager::adjustBlocksIfNeeded(GenerationRequest& sequence)
{
// Allocating a new block when the last token is a block boundary
allocateBlock(sequence, /*shareAmongBeams=*/sequence.getBeamWidth() == 1);
++mGenAllocBlocks;
updateLastCacheBlockOffsets(sequence);
}
}
@@ -2226,6 +2238,73 @@ void WindowBlockManager::releaseLastBlock(GenerationRequest& sequence)
return numFree;
}

[[nodiscard]] SizeType32 WindowBlockManager::getNumFreeSecondaryBlocks() const noexcept
{
return mEvictionPolicy->getNumFreeBlocks(kSecondaryLevel);
}

KvCacheIterationStats WindowBlockManager::getAndResetIterationStats()
{
std::lock_guard<std::mutex> lock(mCachedBlocksRootMutex);
KvCacheIterationStats stats;

// Instantaneous gauges
stats.primaryMaxNumBlocks = getNumPrimaryBlocks();
stats.primaryFreeNumBlocks = getNumFreeBlocks();
stats.primaryUsedNumBlocks = stats.primaryMaxNumBlocks - stats.primaryFreeNumBlocks;
stats.secondaryMaxNumBlocks = getNumSecondaryBlocks();
stats.secondaryFreeNumBlocks = getNumFreeSecondaryBlocks();
stats.secondaryUsedNumBlocks = stats.secondaryMaxNumBlocks - stats.secondaryFreeNumBlocks;

// Compute deltas since last call — context phase
stats.iterAllocTotalBlocks = mAllocTotalBlocks - mPrevAllocTotalBlocks;
stats.iterAllocNewBlocks = mAllocNewBlocks - mPrevAllocNewBlocks;
stats.iterReusedBlocks = mReusedBlocks - mPrevReusedBlocks;
stats.iterFullReusedBlocks = mFullReusedBlocks - mPrevFullReusedBlocks;
stats.iterPartialReusedBlocks = mPartialReusedBlocks - mPrevPartialReusedBlocks;
stats.iterMissedBlocks = mMissedBlocks - mPrevMissedBlocks;

auto const iterTotal = stats.iterReusedBlocks + stats.iterMissedBlocks;
stats.iterCacheHitRate
= iterTotal == 0 ? 0.0f : static_cast<float>(stats.iterReusedBlocks) / static_cast<float>(iterTotal);

// Generation phase
stats.iterGenAllocBlocks = mGenAllocBlocks - mPrevGenAllocBlocks;

// Snapshot current values for next delta
mPrevAllocTotalBlocks = mAllocTotalBlocks;
mPrevAllocNewBlocks = mAllocNewBlocks;
mPrevReusedBlocks = mReusedBlocks;
mPrevFullReusedBlocks = mFullReusedBlocks;
mPrevPartialReusedBlocks = mPartialReusedBlocks;
mPrevMissedBlocks = mMissedBlocks;
mPrevGenAllocBlocks = mGenAllocBlocks;

// Transfer stats (collected from transfer manager)
if (mTransferManager)
{
auto transferStats = mTransferManager->getAndResetTransferStats();
stats.iterOnboardBlocks = transferStats.onboardBlocks;
stats.iterOnboardBytes = transferStats.onboardBytes;
stats.iterOffloadBlocks = transferStats.offloadBlocks;
stats.iterOffloadBytes = transferStats.offloadBytes;
stats.iterIntraDeviceCopyBlocks = transferStats.intraDeviceCopyBlocks;
stats.iterIntraDeviceCopyBytes = transferStats.intraDeviceCopyBytes;
}

return stats;
}

std::map<SizeType32, KvCacheIterationStats> BlockManager::getAndResetIterationStats()
{
std::map<SizeType32, KvCacheIterationStats> perWindowStats;
for (auto& [windowSize, manager] : mWindowBlockManagers)
{
perWindowStats[windowSize] = manager.getAndResetIterationStats();
}
return perWindowStats;
}

std::deque<tle::KVCacheEvent> BlockManager::getLatestEvents(std::optional<std::chrono::milliseconds> timeout) const
{
return mEventManager ? mEventManager->getEvents(timeout) : std::deque<tle::KVCacheEvent>{};