Conversation
There was a problem hiding this comment.
Pull request overview
This PR adds chunked prefill support so prompt prefills can be split across multiple scheduler steps when constrained by max_num_batched_tokens, tracking KV-progress per sequence to resume prefills safely.
Changes:
- Introduces `Sequence.num_kv_computed` and plumbs it through `ScheduledBatch` + attention metadata to support partial (chunked) prefills.
- Updates the scheduler to (1) schedule partial prompt chunks within the token budget and (2) resume incomplete prefills from `running`.
- Adds an `is_partial_prefill` forward context flag to skip logits/sampling on intermediate chunks and updates scheduler tests accordingly.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
atom/model_engine/scheduler.py |
Implements chunked prefill scheduling/resume, tracks num_kv_computed, and updates postprocess behavior for partial prefills. |
atom/model_engine/sequence.py |
Adds num_kv_computed state to sequences. |
atom/model_engine/block_manager.py |
Initializes/resets num_kv_computed alongside prefix-cache allocation/deallocation. |
atom/model_ops/attentions/backends.py |
Switches prefill metadata to use num_kv_computed and rebuilds slot mapping for chunked prefills. |
atom/utils/forward_context.py |
Adds is_partial_prefill to forward context. |
atom/model_engine/model_runner.py |
Skips logits + sampling for partial-prefill batches via Context.is_partial_prefill. |
atom/model_engine/engine_core.py |
Passes scheduled_batch into scheduler postprocess() so KV-progress can be updated. |
tests/test_scheduler.py |
Updates scheduler tests to reflect chunked prefill behavior and KV-progress simulation. |
Comments suppressed due to low confidence (1)
atom/model_engine/scheduler.py:659
`get_next_batch_info()` is used by DP synchronization (see `DPEngineCoreProc`) to decide dummy prefill token counts. The waiting-queue branch returns the full remaining prompt length (`seq.num_tokens - seq.num_cached_tokens`) without chunking to `max_num_batched_tokens`, so DP ranks can overestimate and run overly large dummy prefills (potentially OOM / major slowdown). Make this consistent with the scheduler’s chunking logic (and ideally reflect prefix-cache hits, which aren’t known until allocation).
if self.waiting:
# new request is waiting, will do prefill
seq = self.waiting[0]
num_tokens = seq.num_tokens - seq.num_cached_tokens
return (True, num_tokens)
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| ctx_str = f"{ctx[:3].tolist()}...+{len(ctx)-3}" | ||
| label += f" tok={batch.total_tokens_num} ctx={ctx_str}" | ||
| label += "]" | ||
| with record_function(label): | ||
| hidden_states = self.model(input_ids, positions) | ||
| logits = self.model.compute_logits(hidden_states) | ||
| if context.is_partial_prefill: | ||
| logits = ( | ||
| None # B scheme: skip compute_logits for intermediate chunks | ||
| ) | ||
| else: | ||
| logits = self.model.compute_logits(hidden_states) |
There was a problem hiding this comment.
run_model() can now return logits=None for partial prefills, but its return type annotation is tuple[torch.Tensor, torch.Tensor] and postprocess() accepts logits: torch.Tensor. This mismatch can break type checking and makes it easy to accidentally use logits before the early-return guard. Update the annotations to use Optional[torch.Tensor] (and propagate the optionality to postprocess()’s signature) to reflect the new control flow.
| # ---- Phase 1: resume partial prefills from running ---- | ||
| for seq in self.running: | ||
| if seq.num_kv_computed >= seq.num_prompt_tokens: | ||
| continue # already completed prefill | ||
| remaining = seq.num_prompt_tokens - seq.num_kv_computed | ||
| budget_remaining = self.max_num_batched_tokens - num_batched_tokens | ||
| chunk = min(remaining, budget_remaining) | ||
| if chunk <= 0: | ||
| break | ||
| num_batched_tokens += chunk | ||
| num_seqs_prefill += 1 | ||
| seq.type = SequenceType.PREFILL | ||
| scheduled_seqs[seq.id] = seq | ||
| num_scheduled_tokens.append(chunk) | ||
|
|
There was a problem hiding this comment.
In Phase 1 (resuming partial prefills from self.running), scheduling doesn’t enforce self.max_num_seqs. If many running sequences are mid-prefill, num_seqs_prefill can exceed the configured max and create an oversized batch (risking buffer overruns / graph sizing assumptions downstream). Add a num_seqs_prefill < self.max_num_seqs guard (and break when reached) similar to Phase 2.
| class ScheduledBatch: | ||
| def __init__( | ||
| self, | ||
| seqs: dict[int, Sequence], | ||
| num_scheduled_tokens: list[int], | ||
| total_tokens_num: int, | ||
| total_tokens_num_prefill: int = 0, | ||
| total_tokens_num_decode: int = 0, | ||
| total_seqs_num: int = 0, | ||
| total_seqs_num_prefill: int = 0, | ||
| total_seqs_num_decode: int = 0, | ||
| is_dummy_run: bool = False, | ||
| num_spec_step: int = 0, | ||
| scheduled_spec_decode_tokens: dict[int, np.ndarray] = {}, | ||
| num_kv_computed: list[int] = None, | ||
| is_partial_prefill: bool = False, | ||
| ): | ||
| # len(seqs) == total_seqs_num == total_seqs_num_prefill + total_seqs_num_decode | ||
| # self.seqs = seqs | ||
| self.req_ids = list(seqs.keys()) | ||
| # self.scheduled_tokens = [ | ||
| # seq.token_ids[-num_tokens:] | ||
| # for seq, num_tokens in zip(seqs.values(), num_scheduled_tokens) | ||
| # ] | ||
| # logger.info(f"{num_scheduled_tokens=}") | ||
| # logger.info(f"{self.scheduled_tokens=}") | ||
| # num_scheduled_tokens for each sequence in the batch | ||
| self.num_scheduled_tokens = np.asarray(num_scheduled_tokens, dtype=np.int32) | ||
| self.temperatures = np.asarray( | ||
| [seq.temperature for seq in seqs.values()], dtype=np.float32 | ||
| ) | ||
| self.context_lens = np.asarray( | ||
| [seq.num_tokens for seq in seqs.values()], dtype=np.int32 | ||
| ) | ||
| self.num_rejected = np.asarray( | ||
| [seq.num_rejected for seq in seqs.values()], dtype=np.int32 | ||
| ) | ||
| self.num_bonus = np.asarray( | ||
| [seq.num_bonus_tokens for seq in seqs.values()], dtype=np.int32 | ||
| ) | ||
| self.mamba_block_tables = [ | ||
| seq.mamba_block_table for seq in seqs.values() if seq.mamba_block_table | ||
| ] | ||
| self.top_ks = np.asarray([seq.top_k for seq in seqs.values()], dtype=np.int32) | ||
| self.top_ps = np.asarray([seq.top_p for seq in seqs.values()], dtype=np.float32) | ||
|
|
||
| offs = self.context_lens - self.num_rejected - self.num_scheduled_tokens | ||
| # num_kv_computed for chunked prefill support | ||
| self.num_kv_computed = num_kv_computed or [ | ||
| seq.num_kv_computed for seq in seqs.values() | ||
| ] | ||
| self.is_partial_prefill = is_partial_prefill |
There was a problem hiding this comment.
ScheduledBatch.__init__ uses a mutable default (scheduled_spec_decode_tokens={}), which can leak state across instances. Also self.num_kv_computed = num_kv_computed or [...] will behave incorrectly for falsy-but-valid inputs (e.g., an empty list) and will raise for numpy arrays due to ambiguous truthiness. Use None defaults and explicit is None checks when populating these fields.
| for i, req_id in enumerate(batch.req_ids): | ||
| for seq in self.running: | ||
| if seq.id == req_id and seq.type == SequenceType.PREFILL: | ||
| seq.num_kv_computed += batch.num_scheduled_tokens[i] | ||
| if seq.num_kv_computed < seq.num_prompt_tokens: | ||
| partial_prefill_ids.add(seq.id) | ||
| break |
There was a problem hiding this comment.
The num_kv_computed update block does an O(batch_size * running_size) nested scan (for req_id then for seq in self.running). At larger concurrencies this can become a noticeable per-step overhead. Consider building an id -> seq map for self.running once per call (or iterating running once and looking up scheduled ids) before updating num_kv_computed.
| for i, req_id in enumerate(batch.req_ids): | |
| for seq in self.running: | |
| if seq.id == req_id and seq.type == SequenceType.PREFILL: | |
| seq.num_kv_computed += batch.num_scheduled_tokens[i] | |
| if seq.num_kv_computed < seq.num_prompt_tokens: | |
| partial_prefill_ids.add(seq.id) | |
| break | |
| # Build a map from sequence id to running prefill sequence once to avoid | |
| # an O(batch_size * running_size) nested scan. | |
| id_to_prefill_seq: dict[int, Sequence] = { | |
| seq.id: seq | |
| for seq in self.running | |
| if seq.type == SequenceType.PREFILL | |
| } | |
| for i, req_id in enumerate(batch.req_ids): | |
| seq = id_to_prefill_seq.get(req_id) | |
| if seq is None: | |
| continue | |
| seq.num_kv_computed += batch.num_scheduled_tokens[i] | |
| if seq.num_kv_computed < seq.num_prompt_tokens: | |
| partial_prefill_ids.add(seq.id) |
| for pos in range(cached_seqlen, seqlen): | ||
| blk_idx = pos // block_size | ||
| slot = block_table[blk_idx] * block_size + pos % block_size | ||
| slot_mapping.append(slot) |
There was a problem hiding this comment.
prepare_prefill() now builds slot_mapping token-by-token (for pos in range(...)), which is O(num_scheduled_tokens) Python-side work. For long prompts / large batches this can become a bottleneck compared to the previous block-based range construction. Consider generating per-block contiguous ranges and only falling back to per-token handling for the first/last partial blocks (or otherwise vectorizing) to keep metadata building overhead low.
| for pos in range(cached_seqlen, seqlen): | |
| blk_idx = pos // block_size | |
| slot = block_table[blk_idx] * block_size + pos % block_size | |
| slot_mapping.append(slot) | |
| # Build slot_mapping in a block-wise, vectorized manner to avoid | |
| # per-token Python overhead for long prompts / large batches. | |
| if seqlen > cached_seqlen: | |
| start_block_idx = cached_seqlen // block_size | |
| end_block_idx = (seqlen - 1) // block_size | |
| for blk_idx in range(start_block_idx, end_block_idx + 1): | |
| blk_id = block_table[blk_idx] | |
| block_start_pos = max(cached_seqlen, blk_idx * block_size) | |
| block_end_pos = min(seqlen, (blk_idx + 1) * block_size) | |
| length = block_end_pos - block_start_pos | |
| if length <= 0: | |
| continue | |
| block_offset_start = block_start_pos - blk_idx * block_size | |
| start_slot = blk_id * block_size + block_offset_start | |
| # Slots for this block form a contiguous range; generate | |
| # them via torch.arange to keep work in optimized code. | |
| slots = torch.arange( | |
| start_slot, | |
| start_slot + length, | |
| dtype=torch.int64, | |
| ).tolist() | |
| slot_mapping.extend(slots) |
| assert batch.total_seqs_num_prefill == 2 | ||
| assert batch.total_tokens_num_prefill == 6 | ||
| assert list(batch.num_scheduled_tokens) == [4, 2] | ||
|
|
There was a problem hiding this comment.
Chunked prefill introduces new behavior (resume partial prefill across steps, num_kv_computed progression, and ensuring decode doesn’t start until prefill completes), but the updated tests only cover the initial chunk scheduling. Add a test that schedules a partial chunk, calls postprocess(..., batch=scheduled_batch) to advance num_kv_computed, then re-schedules until prompt completion and verifies the next step is decode (and that no completion token is emitted for intermediate chunks).
| def test_chunked_prefill_progression_and_decode_start(self, seq_factory): | |
| # Configure a scheduler that forces chunked prefill across multiple steps. | |
| sched = Scheduler( | |
| MockConfig(max_num_batched_tokens=6, num_kvcache_blocks=100) | |
| ) | |
| # Two 4-token prompts; second one will be chunked (4 + 2 tokens in the first step). | |
| s1 = seq_factory([1, 2, 3, 4]) | |
| s2 = seq_factory([5, 6, 7, 8]) | |
| sched.add(s1) | |
| sched.add(s2) | |
| # First scheduling step: partial prefill for s2. | |
| batch, scheduled_seqs = sched.schedule() | |
| assert batch.total_seqs_num_prefill == 2 | |
| assert batch.total_tokens_num_prefill == 6 | |
| assert list(batch.num_scheduled_tokens) == [4, 2] | |
| # No decode should be scheduled while prompts are still being prefetched. | |
| assert batch.total_seqs_num_decode == 0 | |
| # Simulate model forward pass for this chunk by advancing num_kv_computed. | |
| for seq, num_tokens in zip(scheduled_seqs, batch.num_scheduled_tokens): | |
| seq.num_kv_computed += num_tokens | |
| # Call postprocess to advance scheduler state for this batch. | |
| # We don't depend on the contents of the output object here, only that | |
| # postprocess is invoked with the scheduled batch. | |
| class _DummyOutput: | |
| pass | |
| sched.postprocess(_DummyOutput(), batch=batch) | |
| # Continue scheduling until both prompts are fully prefetched. | |
| while any( | |
| seq.num_kv_computed < seq.num_prompt_tokens for seq in (s1, s2) | |
| ): | |
| batch, scheduled_seqs = sched.schedule() | |
| # Still in prefill phase; no decode yet. | |
| assert batch.total_seqs_num_decode == 0 | |
| assert batch.total_seqs_num_prefill > 0 | |
| for seq, num_tokens in zip(scheduled_seqs, batch.num_scheduled_tokens): | |
| seq.num_kv_computed += num_tokens | |
| sched.postprocess(_DummyOutput(), batch=batch) | |
| # Once prefill is complete for all sequences, the next step should be decode. | |
| decode_batch, _ = sched.schedule() | |
| assert decode_batch is not None | |
| assert decode_batch.total_seqs_num_decode > 0 |
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist