
KV cache benchmark: multi-turn full context reload, block-layer traci…#290

Merged
dslik merged 1 commit into mlcommons:main from
hazemawadalla:main
Mar 30, 2026

Conversation

@hazemawadalla
Contributor

…ng & fio workload distiller

Problem:
The multi-turn conversation path in process_requests() only read the
immediately previous turn (turn N-1) when resuming a conversation. In
a real inference serving environment with KV cache offloading, resuming
a conversation requires reloading the full prior context from storage;
every previous turn that survived LRU eviction needs to be read back.
Reading only N-1 understated the read I/O by a factor of N/2 for deep
conversations. Additionally, the decode probe read latency in step 5
was silently dropped from the per-request storage_latency accumulator;
every other read in process_requests() accumulated correctly except
this one, and decode_latencies only recorded the probe read without
the batched decode reads that follow it.

Separately, we had no way to decompose what was happening at the block
layer during benchmark runs. The L4 "device" latency measures the time
to read an entire .npy file through np.load(); these files can reach
500 MB to 2 GB depending on context length and model, and the kernel
splits each read into hundreds of NVMe commands at the MDTS boundary.
The P95 device read latency reflects the total time to load a large
KV cache entry, but not the individual NVMe commands beneath it. This
change adds low-overhead block-layer telemetry, enabled via a single
flag: --enable-latency-tracing.
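To make the scale concrete, here is a back-of-the-envelope sketch (illustrative numbers only, not code from the benchmark) of how many NVMe commands a single np.load() decomposes into when the kernel splits reads at an assumed 1 MB MDTS boundary:

```python
# Illustrative arithmetic: one large sequential file read becomes many
# NVMe commands, each bounded by the device's MDTS. The 1 MB value is an
# assumption for this sketch, not a measured property of any drive.
MDTS_BYTES = 1 * 1024 * 1024

def nvme_commands(file_bytes: int, mdts: int = MDTS_BYTES) -> int:
    """Number of device commands a sequential read is split into."""
    return (file_bytes + mdts - 1) // mdts

# A 500 MB KV cache entry becomes ~500 commands; a 2 GB entry ~2048.
print(nvme_commands(500 * 1024 * 1024))       # 500
print(nvme_commands(2 * 1024 * 1024 * 1024))  # 2048
```

This is why the L4 "device" latency and the per-command D2C latency answer different questions, and why both are worth capturing.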

Fixes (benchmark.py):

  • Multi-turn step 2 now calls get_all_previous_turn_keys() and reads ALL previous turns via access_cache(). Entries evicted by the LRU waterfall return (None, 0.0) immediately, with zero I/O and zero memory allocation; surviving entries get real np.load reads with measured latency. The multi-turn hit rate in the output now reflects the true conversation cache survival rate under eviction pressure: we saw a 45% hit rate with 10 users on a DC3000ME, which tells you exactly how much prior context the storage tier can sustain.
  • Steps 2+3 moved inside the if not self.decode_only guard; in decode-only mode writes are skipped, so multi-turn reads would always miss.
  • storage_latency += read_latency added after the step 5 decode probe read, matching every other read in the method.
  • decode_latencies now accumulates probe read + all batched decode reads per request, not just the probe.
  • max_turns_per_conv hard cap enforced in user_worker; previously the config value was read but never checked, so conversations could grow unbounded regardless of the setting.
  • Memory safety check at startup: estimates peak RAM from the formula peak = (workers x 2 x mean_entry_bytes) + baseline and warns with safe --num-users / --max-concurrent-allocs values if the estimate exceeds 85% of available RAM.
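The step-2 fix can be sketched as follows. get_all_previous_turn_keys() and access_cache() are the names from this PR; the ToyCache class and reload_full_context() wrapper are hypothetical stand-ins for illustration, not the benchmark's real structure:

```python
class ToyCache:
    """In-memory stand-in for the benchmark's cache tier (hypothetical)."""
    def __init__(self, entries):
        self.entries = entries  # key -> (payload, read_latency_seconds)

    def get_all_previous_turn_keys(self, conv_id, turn):
        return [f"{conv_id}:turn{t}" for t in range(1, turn)]

    def access_cache(self, key):
        # Evicted entries return (None, 0.0): zero I/O, zero allocation.
        return self.entries.get(key, (None, 0.0))


def reload_full_context(cache, conv_id, turn):
    """Read every surviving previous turn, not just turn N-1."""
    storage_latency, hits, misses = 0.0, 0, 0
    for key in cache.get_all_previous_turn_keys(conv_id, turn):
        data, read_latency = cache.access_cache(key)
        if data is None:          # lost to the LRU waterfall
            misses += 1
            continue
        hits += 1
        storage_latency += read_latency
    return storage_latency, hits, misses
```

With turn 2 evicted and turns 1 and 3 surviving, resuming at turn 4 yields two hits and one miss, and only the hits contribute to storage_latency.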

Block-layer tracing (--enable-latency-tracing):

  • Spawns bpftrace as a sudo subprocess before the benchmark run, sends SIGINT after, parses the histogram output into structured data.
  • 15 histograms captured: D2C read/write (actual NVMe hardware time per command), Q2D read/write (I/O scheduler queue), VFS read/write (application visible), fsync, write-to-fsync serialization gap, fadvise-to-read gap, block size distribution (bssplit) read/write, in-flight I/O count at dispatch read/write, and LBA heatmap read/write (10 GB linear buckets via lhist).
  • bpftrace 0.14 compatible: uses comm == instead of str(comm) ==, END block removed (bpftrace auto-prints maps on SIGINT), D2C measured unconditionally at block_rq_issue (not gated on block_rq_insert which NVMe blk-mq direct dispatch bypasses).
  • Results flow to stdout (P50/P95/P99 per histogram with raw bars), JSON (full bucket data under device_latency_tracing key), and XLSX (Device Tracing summary sheet + Trace Histograms raw data sheet).
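The histogram parsing step can be sketched like this. It assumes the standard bpftrace hist()/lhist() text format ("[lo, hi)  count |bars|") and is an illustrative simplification, not the PR's actual parser:

```python
import re

# Parse bpftrace histogram bucket lines into (lo, hi, count) tuples,
# then derive a percentile from the cumulative bucket counts.
BUCKET_RE = re.compile(r"\[(\S+),\s*(\S+)\)\s+(\d+)")
UNITS = {"": 1, "K": 1 << 10, "M": 1 << 20, "G": 1 << 30}

def to_units(tok: str) -> int:
    m = re.match(r"(\d+)([KMG]?)", tok)
    return int(m.group(1)) * UNITS[m.group(2)]

def parse_hist(text: str):
    return [(to_units(lo), to_units(hi), int(n))
            for lo, hi, n in BUCKET_RE.findall(text)]

def percentile(buckets, p: float) -> int:
    """Upper bound of the bucket containing the p-th quantile."""
    total = sum(n for _, _, n in buckets)
    acc = 0
    for lo, hi, n in buckets:
        acc += n
        if acc >= p * total:
            return hi
    return buckets[-1][1]
```

Reporting the bucket's upper bound is a conservative choice; with power-of-two hist() buckets the true percentile lies somewhere inside the bucket.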

fio workload distiller:

  • _generate_fio_workload() in benchmark.py distills the traced bssplit, read/write ratio, queue depth, and thinktime into a standalone fio .ini file saved as fio_kv_cache_workload_{timestamp}.ini.
  • utils/distill_fio.py is a standalone Python script that parses raw bpftrace output and generates the same fio config without needing the benchmark installed. Point it at vLLM, llm-d, or any process and get a representative fio workload for bare-metal drive comparison.
  • utils/storage_latency_stack.sh updated with --fio flag that captures trace output and pipes through distill_fio.py automatically.
  • bssplit derived from block size histograms with separate read/write splits per fio spec. rwmixread from I/O count ratio. iodepth from in-flight I/O histogram P50. thinktime from write-to-fsync gap P50. thinktime_iotime commented out for fio <3.28 compatibility.
  • Generated config includes D2C latency summary and LBA hot zone in the header comments for reference.
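The distillation step can be sketched as below. The bssplit syntax follows fio's documented format (size/percentage pairs joined by ":", read and write splits separated by ","); the function names and job layout here are hypothetical, not the PR's _generate_fio_workload():

```python
def make_bssplit(size_counts):
    """Turn {block_size_bytes: io_count} into fio bssplit syntax."""
    total = sum(size_counts.values())
    parts = [f"{bs // 1024}k/{round(100 * n / total)}"
             for bs, n in sorted(size_counts.items())]
    return ":".join(parts)

def make_fio_job(read_split, write_split, rwmixread, iodepth, thinktime_us):
    """Assemble a minimal fio job section from the distilled parameters."""
    return "\n".join([
        "[kv_cache_replay]",
        "rw=randrw",
        f"rwmixread={rwmixread}",
        f"bssplit={read_split},{write_split}",  # read split, then write split
        f"iodepth={iodepth}",
        f"thinktime={thinktime_us}",            # microseconds, per fio default
        "direct=1",
        "ioengine=libaio",
    ])
```

For the traced workload described here, a read split dominated by 1 MB blocks would come out as something like 4k/1:1024k/99.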

New method (cache.py):

  • check_cache_exists(): metadata-only existence check on cache_entries. Returns (location, size) or (None, 0). No np.load, no LRU update, no memory allocation.
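The stated contract (no np.load, no LRU update, no allocation) can be sketched as a pure dictionary lookup. The CacheIndex wrapper below is a hypothetical stand-in for the real cache.py class; only the method name and return shape come from this PR:

```python
class CacheIndex:
    """Minimal stand-in holding the cache_entries metadata map."""
    def __init__(self):
        self.cache_entries = {}  # key -> (location, size_bytes)

    def check_cache_exists(self, key):
        """Metadata-only lookup: (location, size) if present, else (None, 0)."""
        entry = self.cache_entries.get(key)
        return entry if entry is not None else (None, 0)
```

Because it never touches the payload or the LRU ordering, callers can probe for survivors without perturbing eviction behavior.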

Tests:

  • test_part3c updated: shows the triangular read pattern where turn 4 reads turns 1, 2, 3. Asserts 10 total decode_reads (6 multi-turn reads + 4 decode reads).
  • test_part3d added: 6-turn conversation on NVMe with capacity for 3.5 entries. Proves eviction works correctly; later turns show a mix of hits (recent turns still cached) and misses (old turns evicted). Documents both miss counters: cache.stats['cache_misses'] (tier-level, all types) vs results['multi_turn_cache_misses'] (benchmark-level, step 2 only).
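The counts asserted in test_part3c follow from the triangular read pattern, a quick check (illustrative arithmetic, not the test's code):

```python
def expected_reads(turns: int):
    """(multi_turn_reads, decode_reads, total) under full-context reload."""
    multi = sum(t - 1 for t in range(1, turns + 1))  # turn t rereads 1..t-1
    decode = turns                                   # one decode probe per turn
    return multi, decode, multi + decode

print(expected_reads(4))  # (6, 4, 10)
```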

Documentation:

  • README.md: Added Block-Layer Latency Tracing section with usage examples, histogram reference table, fio distiller docs, and standalone vLLM/llm-d tracing instructions. Updated Unit Testing section with TestVisualizeUserRequestFlow examples and flags.
  • Proposal document: Added section 8.5 covering tracing motivation, histogram reference, fio distiller, and standalone usage.

Tested on Kingston DC3000ME PCIe Gen5 NVMe (447 GB), llama3.1-8b, 1-100 users, seed 42, with bpftrace block-layer validation confirming D2C read P50 of 1-2 ms per NVMe command and bssplit dominated by 1 MB blocks (99% of read I/Os, matching MDTS splits of the large KV cache .npy files).

@hazemawadalla hazemawadalla requested a review from a team March 26, 2026 08:35
@github-actions

MLCommons CLA bot All contributors have signed the MLCommons CLA ✍️ ✅

Contributor

@dslik dslik left a comment


Looks good to me. Thanks!

@dslik dslik merged commit cc0fc51 into mlcommons:main Mar 30, 2026
2 checks passed
@github-actions github-actions bot locked and limited conversation to collaborators Mar 30, 2026