feat: enable chunk prefill #408

Open
jiayyu wants to merge 2 commits into main from chunk_prefill

Conversation

jiayyu (Contributor) commented Mar 25, 2026

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

jiayyu closed this Mar 25, 2026
jiayyu reopened this Mar 25, 2026
jiayyu marked this pull request as ready for review March 25, 2026 08:41
Copilot AI review requested due to automatic review settings March 25, 2026 08:41
jiayyu closed this Mar 25, 2026
jiayyu reopened this Mar 25, 2026
Copilot AI left a comment

Pull request overview

This PR adds chunked prefill support so prompt prefills can be split across multiple scheduler steps when constrained by max_num_batched_tokens, tracking KV-progress per sequence to resume prefills safely.

Changes:

  • Introduces Sequence.num_kv_computed and plumbs it through ScheduledBatch + attention metadata to support partial (chunked) prefills.
  • Updates the scheduler to (1) schedule partial prompt chunks within token budget and (2) resume incomplete prefills from running.
  • Adds an is_partial_prefill forward context flag to skip logits/sampling on intermediate chunks and updates scheduler tests accordingly.
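The budget arithmetic behind this scheme can be illustrated with a small standalone sketch (`chunk_prefill_steps` is a hypothetical helper for illustration, not code from this PR): each scheduler step takes `min(remaining prompt tokens, max_num_batched_tokens)` tokens and advances the per-sequence KV progress counter until the whole prompt is prefilled.

```python
def chunk_prefill_steps(num_prompt_tokens: int, max_num_batched_tokens: int) -> list[int]:
    """Return the chunk sizes a single prompt would be split into across steps.

    Illustrative only: mirrors the min(remaining, budget) scheduling math that
    advances num_kv_computed, assuming one sequence and an otherwise empty batch.
    """
    num_kv_computed = 0
    chunks = []
    while num_kv_computed < num_prompt_tokens:
        remaining = num_prompt_tokens - num_kv_computed
        chunk = min(remaining, max_num_batched_tokens)
        chunks.append(chunk)
        num_kv_computed += chunk
    return chunks

# A 10-token prompt under a 4-token budget prefills in three steps.
print(chunk_prefill_steps(10, 4))  # → [4, 4, 2]
```

Only the final chunk produces logits and a sampled token; the intermediate chunks exist purely to populate the KV cache.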

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 6 comments.

Summary per file:

  • atom/model_engine/scheduler.py: Implements chunked prefill scheduling/resume, tracks num_kv_computed, and updates postprocess behavior for partial prefills.
  • atom/model_engine/sequence.py: Adds num_kv_computed state to sequences.
  • atom/model_engine/block_manager.py: Initializes/resets num_kv_computed alongside prefix-cache allocation/deallocation.
  • atom/model_ops/attentions/backends.py: Switches prefill metadata to use num_kv_computed and rebuilds slot mapping for chunked prefills.
  • atom/utils/forward_context.py: Adds is_partial_prefill to forward context.
  • atom/model_engine/model_runner.py: Skips logits + sampling for partial-prefill batches via Context.is_partial_prefill.
  • atom/model_engine/engine_core.py: Passes scheduled_batch into scheduler postprocess() so KV-progress can be updated.
  • tests/test_scheduler.py: Updates scheduler tests to reflect chunked prefill behavior and KV-progress simulation.
Comments suppressed due to low confidence (1)

atom/model_engine/scheduler.py:659

  • get_next_batch_info() is used by DP synchronization (see DPEngineCoreProc) to decide dummy prefill token counts. The waiting-queue branch returns the full remaining prompt length (seq.num_tokens - seq.num_cached_tokens) without chunking to max_num_batched_tokens, so DP ranks can overestimate and run overly large dummy prefills (potentially OOM / major slowdown). Make this consistent with the scheduler’s chunking logic (and ideally reflect prefix-cache hits, which aren’t known until allocation).
        if self.waiting:
            # new request is waiting, will do prefill
            seq = self.waiting[0]
            num_tokens = seq.num_tokens - seq.num_cached_tokens
            return (True, num_tokens)
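One way to make this branch consistent with the scheduler's chunking is to clamp the estimate to the per-step budget. A minimal sketch of that fix (the helper name and standalone signature are illustrative; in the PR this logic would live inside `get_next_batch_info()`):

```python
# Sketch: clamp the waiting-queue prefill estimate to the token budget so DP
# ranks never size dummy prefills larger than the scheduler can actually batch.
# Names mirror the snippet above; this is an illustration, not the PR's patch.
def estimate_prefill_tokens(
    num_tokens: int,
    num_cached_tokens: int,
    max_num_batched_tokens: int,
) -> int:
    remaining = num_tokens - num_cached_tokens
    return min(remaining, max_num_batched_tokens)

# An 8192-token prompt under a 2048-token budget is reported as 2048, not 8192.
print(estimate_prefill_tokens(8192, 0, 2048))  # → 2048
```

As the comment notes, prefix-cache hits are not known until allocation, so even the clamped value remains an upper bound rather than an exact count.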


Comment on lines 1557 to +1567
  ctx_str = f"{ctx[:3].tolist()}...+{len(ctx)-3}"
  label += f" tok={batch.total_tokens_num} ctx={ctx_str}"
  label += "]"
  with record_function(label):
      hidden_states = self.model(input_ids, positions)
-     logits = self.model.compute_logits(hidden_states)
+     if context.is_partial_prefill:
+         logits = None  # B scheme: skip compute_logits for intermediate chunks
+     else:
+         logits = self.model.compute_logits(hidden_states)
Copilot AI Mar 25, 2026

run_model() can now return logits=None for partial prefills, but its return type annotation is tuple[torch.Tensor, torch.Tensor] and postprocess() accepts logits: torch.Tensor. This mismatch can break type checking and makes it easy to accidentally use logits before the early-return guard. Update the annotations to use Optional[torch.Tensor] (and propagate the optionality to postprocess()’s signature) to reflect the new control flow.
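The suggested annotation change can be sketched as follows, using plain strings as stand-ins for tensors so the control flow is visible without a model (the function bodies are hypothetical simplifications of `run_model()`/`postprocess()`):

```python
from typing import Optional

# Illustrative sketch: Optional in the return type makes the partial-prefill
# path visible to type checkers and forces callers to guard before use.
def run_model(is_partial_prefill: bool) -> tuple[str, Optional[str]]:
    hidden_states = "hidden"  # stand-in for the model's hidden states tensor
    # Intermediate chunks skip logits computation entirely.
    logits = None if is_partial_prefill else "logits"
    return hidden_states, logits

def postprocess(logits: Optional[str]) -> Optional[str]:
    # Early-return guard: partial prefills produce no sample.
    if logits is None:
        return None
    return f"sampled from {logits}"

print(postprocess(run_model(True)[1]))   # → None
print(postprocess(run_model(False)[1]))  # → sampled from logits
```

Propagating `Optional` through `postprocess()`'s signature keeps the new `None` case explicit at every call site rather than only at the early-return.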

Comment on lines +354 to +368
# ---- Phase 1: resume partial prefills from running ----
for seq in self.running:
    if seq.num_kv_computed >= seq.num_prompt_tokens:
        continue  # already completed prefill
    remaining = seq.num_prompt_tokens - seq.num_kv_computed
    budget_remaining = self.max_num_batched_tokens - num_batched_tokens
    chunk = min(remaining, budget_remaining)
    if chunk <= 0:
        break
    num_batched_tokens += chunk
    num_seqs_prefill += 1
    seq.type = SequenceType.PREFILL
    scheduled_seqs[seq.id] = seq
    num_scheduled_tokens.append(chunk)

Copilot AI Mar 25, 2026

In Phase 1 (resuming partial prefills from self.running), scheduling doesn’t enforce self.max_num_seqs. If many running sequences are mid-prefill, num_seqs_prefill can exceed the configured max and create an oversized batch (risking buffer overruns / graph sizing assumptions downstream). Add a num_seqs_prefill < self.max_num_seqs guard (and break when reached) similar to Phase 2.
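The suggested guard can be sketched in isolation like this (`Seq` and `resume_partial_prefills` are simplified stand-ins for the PR's `Sequence` type and Phase 1 loop, not its actual code):

```python
from dataclasses import dataclass

@dataclass
class Seq:
    id: int
    num_prompt_tokens: int
    num_kv_computed: int = 0

def resume_partial_prefills(running, max_num_batched_tokens, max_num_seqs):
    """Phase-1 sketch with the review's max_num_seqs guard added."""
    scheduled = []
    num_batched_tokens = 0
    num_seqs_prefill = 0
    for seq in running:
        if seq.num_kv_computed >= seq.num_prompt_tokens:
            continue  # prefill already complete
        if num_seqs_prefill >= max_num_seqs:
            break  # the guard suggested above: cap batch size
        chunk = min(seq.num_prompt_tokens - seq.num_kv_computed,
                    max_num_batched_tokens - num_batched_tokens)
        if chunk <= 0:
            break
        num_batched_tokens += chunk
        num_seqs_prefill += 1
        scheduled.append((seq.id, chunk))
    return scheduled

# Four mid-prefill sequences, max_num_seqs=2: only two are scheduled.
running = [Seq(i, 4) for i in range(4)]
print(resume_partial_prefills(running, 100, 2))  # → [(0, 4), (1, 4)]
```

Without the `num_seqs_prefill >= max_num_seqs` check, all four sequences would be batched even though the token budget allows it, violating the sequence-count limit Phase 2 enforces.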

Comment on lines 179 to +220
class ScheduledBatch:
    def __init__(
        self,
        seqs: dict[int, Sequence],
        num_scheduled_tokens: list[int],
        total_tokens_num: int,
        total_tokens_num_prefill: int = 0,
        total_tokens_num_decode: int = 0,
        total_seqs_num: int = 0,
        total_seqs_num_prefill: int = 0,
        total_seqs_num_decode: int = 0,
        is_dummy_run: bool = False,
        num_spec_step: int = 0,
        scheduled_spec_decode_tokens: dict[int, np.ndarray] = {},
        num_kv_computed: list[int] = None,
        is_partial_prefill: bool = False,
    ):
        # len(seqs) == total_seqs_num == total_seqs_num_prefill + total_seqs_num_decode
        # self.seqs = seqs
        self.req_ids = list(seqs.keys())
        # self.scheduled_tokens = [
        #     seq.token_ids[-num_tokens:]
        #     for seq, num_tokens in zip(seqs.values(), num_scheduled_tokens)
        # ]
        # logger.info(f"{num_scheduled_tokens=}")
        # logger.info(f"{self.scheduled_tokens=}")
        # num_scheduled_tokens for each sequence in the batch
        self.num_scheduled_tokens = np.asarray(num_scheduled_tokens, dtype=np.int32)
        self.temperatures = np.asarray(
            [seq.temperature for seq in seqs.values()], dtype=np.float32
        )
        self.context_lens = np.asarray(
            [seq.num_tokens for seq in seqs.values()], dtype=np.int32
        )
        self.num_rejected = np.asarray(
            [seq.num_rejected for seq in seqs.values()], dtype=np.int32
        )
        self.num_bonus = np.asarray(
            [seq.num_bonus_tokens for seq in seqs.values()], dtype=np.int32
        )
        self.mamba_block_tables = [
            seq.mamba_block_table for seq in seqs.values() if seq.mamba_block_table
        ]
        self.top_ks = np.asarray([seq.top_k for seq in seqs.values()], dtype=np.int32)
        self.top_ps = np.asarray([seq.top_p for seq in seqs.values()], dtype=np.float32)

        offs = self.context_lens - self.num_rejected - self.num_scheduled_tokens
        # num_kv_computed for chunked prefill support
        self.num_kv_computed = num_kv_computed or [
            seq.num_kv_computed for seq in seqs.values()
        ]
        self.is_partial_prefill = is_partial_prefill
Copilot AI Mar 25, 2026

ScheduledBatch.__init__ uses a mutable default (scheduled_spec_decode_tokens={}), which can leak state across instances. Also self.num_kv_computed = num_kv_computed or [...] will behave incorrectly for falsy-but-valid inputs (e.g., an empty list) and will raise for numpy arrays due to ambiguous truthiness. Use None defaults and explicit is None checks when populating these fields.
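The recommended None-default pattern looks like this in isolation (`ScheduledBatchSketch` is a simplified stand-in, not the full `ScheduledBatch` signature):

```python
# Sketch of the None-default pattern the review comment recommends.
class ScheduledBatchSketch:
    def __init__(self, num_kv_computed=None, scheduled_spec_decode_tokens=None):
        # Explicit `is None` checks: an empty list (or a numpy array, whose
        # truthiness is ambiguous) passes through unchanged, and each instance
        # gets its own fresh dict instead of sharing one mutable default.
        if scheduled_spec_decode_tokens is None:
            scheduled_spec_decode_tokens = {}
        self.scheduled_spec_decode_tokens = scheduled_spec_decode_tokens
        if num_kv_computed is None:
            num_kv_computed = []
        self.num_kv_computed = num_kv_computed

a = ScheduledBatchSketch()
a.scheduled_spec_decode_tokens["x"] = 1
b = ScheduledBatchSketch()
print(b.scheduled_spec_decode_tokens)  # → {} (no state leaked from `a`)
```

With the original `dict[int, np.ndarray] = {}` default, the mutation through `a` above would have been visible in `b`, since every call without the argument shares the single dict created at function-definition time.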

Comment on lines +488 to +494
for i, req_id in enumerate(batch.req_ids):
    for seq in self.running:
        if seq.id == req_id and seq.type == SequenceType.PREFILL:
            seq.num_kv_computed += batch.num_scheduled_tokens[i]
            if seq.num_kv_computed < seq.num_prompt_tokens:
                partial_prefill_ids.add(seq.id)
            break
Copilot AI Mar 25, 2026

The num_kv_computed update block does an O(batch_size * running_size) nested scan (for req_id then for seq in self.running). At larger concurrencies this can become a noticeable per-step overhead. Consider building an id -> seq map for self.running once per call (or iterating running once and looking up scheduled ids) before updating num_kv_computed.

Suggested change
- for i, req_id in enumerate(batch.req_ids):
-     for seq in self.running:
-         if seq.id == req_id and seq.type == SequenceType.PREFILL:
-             seq.num_kv_computed += batch.num_scheduled_tokens[i]
-             if seq.num_kv_computed < seq.num_prompt_tokens:
-                 partial_prefill_ids.add(seq.id)
-             break
+ # Build a map from sequence id to running prefill sequence once to avoid
+ # an O(batch_size * running_size) nested scan.
+ id_to_prefill_seq: dict[int, Sequence] = {
+     seq.id: seq
+     for seq in self.running
+     if seq.type == SequenceType.PREFILL
+ }
+ for i, req_id in enumerate(batch.req_ids):
+     seq = id_to_prefill_seq.get(req_id)
+     if seq is None:
+         continue
+     seq.num_kv_computed += batch.num_scheduled_tokens[i]
+     if seq.num_kv_computed < seq.num_prompt_tokens:
+         partial_prefill_ids.add(seq.id)

Comment on lines +174 to +177
for pos in range(cached_seqlen, seqlen):
    blk_idx = pos // block_size
    slot = block_table[blk_idx] * block_size + pos % block_size
    slot_mapping.append(slot)
Copilot AI Mar 25, 2026

prepare_prefill() now builds slot_mapping token-by-token (for pos in range(...)), which is O(num_scheduled_tokens) Python-side work. For long prompts / large batches this can become a bottleneck compared to the previous block-based range construction. Consider generating per-block contiguous ranges and only falling back to per-token handling for the first/last partial blocks (or otherwise vectorizing) to keep metadata building overhead low.

Suggested change
- for pos in range(cached_seqlen, seqlen):
-     blk_idx = pos // block_size
-     slot = block_table[blk_idx] * block_size + pos % block_size
-     slot_mapping.append(slot)
+ # Build slot_mapping in a block-wise, vectorized manner to avoid
+ # per-token Python overhead for long prompts / large batches.
+ if seqlen > cached_seqlen:
+     start_block_idx = cached_seqlen // block_size
+     end_block_idx = (seqlen - 1) // block_size
+     for blk_idx in range(start_block_idx, end_block_idx + 1):
+         blk_id = block_table[blk_idx]
+         block_start_pos = max(cached_seqlen, blk_idx * block_size)
+         block_end_pos = min(seqlen, (blk_idx + 1) * block_size)
+         length = block_end_pos - block_start_pos
+         if length <= 0:
+             continue
+         block_offset_start = block_start_pos - blk_idx * block_size
+         start_slot = blk_id * block_size + block_offset_start
+         # Slots for this block form a contiguous range; generate
+         # them via torch.arange to keep work in optimized code.
+         slots = torch.arange(
+             start_slot,
+             start_slot + length,
+             dtype=torch.int64,
+         ).tolist()
+         slot_mapping.extend(slots)

assert batch.total_seqs_num_prefill == 2
assert batch.total_tokens_num_prefill == 6
assert list(batch.num_scheduled_tokens) == [4, 2]

Copilot AI Mar 25, 2026

Chunked prefill introduces new behavior (resume partial prefill across steps, num_kv_computed progression, and ensuring decode doesn’t start until prefill completes), but the updated tests only cover the initial chunk scheduling. Add a test that schedules a partial chunk, calls postprocess(..., batch=scheduled_batch) to advance num_kv_computed, then re-schedules until prompt completion and verifies the next step is decode (and that no completion token is emitted for intermediate chunks).

Suggested change
def test_chunked_prefill_progression_and_decode_start(self, seq_factory):
    # Configure a scheduler that forces chunked prefill across multiple steps.
    sched = Scheduler(
        MockConfig(max_num_batched_tokens=6, num_kvcache_blocks=100)
    )
    # Two 4-token prompts; the second will be chunked (4 + 2 tokens in the first step).
    s1 = seq_factory([1, 2, 3, 4])
    s2 = seq_factory([5, 6, 7, 8])
    sched.add(s1)
    sched.add(s2)
    # First scheduling step: partial prefill for s2.
    batch, scheduled_seqs = sched.schedule()
    assert batch.total_seqs_num_prefill == 2
    assert batch.total_tokens_num_prefill == 6
    assert list(batch.num_scheduled_tokens) == [4, 2]
    # No decode should be scheduled while prompts are still being prefilled.
    assert batch.total_seqs_num_decode == 0
    # Simulate the model forward pass for this chunk by advancing num_kv_computed.
    for seq, num_tokens in zip(scheduled_seqs, batch.num_scheduled_tokens):
        seq.num_kv_computed += num_tokens
    # Call postprocess to advance scheduler state for this batch.
    # We don't depend on the contents of the output object here, only that
    # postprocess is invoked with the scheduled batch.
    class _DummyOutput:
        pass
    sched.postprocess(_DummyOutput(), batch=batch)
    # Continue scheduling until both prompts are fully prefilled.
    while any(
        seq.num_kv_computed < seq.num_prompt_tokens for seq in (s1, s2)
    ):
        batch, scheduled_seqs = sched.schedule()
        # Still in the prefill phase; no decode yet.
        assert batch.total_seqs_num_decode == 0
        assert batch.total_seqs_num_prefill > 0
        for seq, num_tokens in zip(scheduled_seqs, batch.num_scheduled_tokens):
            seq.num_kv_computed += num_tokens
        sched.postprocess(_DummyOutput(), batch=batch)
    # Once prefill is complete for all sequences, the next step should be decode.
    decode_batch, _ = sched.schedule()
    assert decode_batch is not None
    assert decode_batch.total_seqs_num_decode > 0