
feat: PD disaggregation #253

Open
inkcherry wants to merge 27 commits into ROCm:main from inkcherry:atom_pd

Conversation


@inkcherry inkcherry commented Mar 2, 2026

Key scheduling changes:

  • New sequence status WAITING_FOR_REMOTE_KVS: sequences that have been allocated blocks but are waiting for remote KV data to arrive.
  • Deferred block deallocation: On the producer (prefill) side, finished sequences hold their blocks until the consumer confirms transfer completion, preventing use-after-free.
  • Connector metadata per batch: Each ScheduledBatch carries connector_meta_output which is dispatched to workers to trigger async KV loading before the forward pass.
  • First-decode-after-remote-prefill: Sequences transitioning from WAITING_FOR_REMOTE_KVS to RUNNING skip local block re-allocation (blocks were pre-allocated during the async load phase) and enter decode directly.
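The status transition described above can be sketched as follows. `SequenceStatus`, `is_first_decode`, and `on_remote_kv_arrived` are illustrative names following the PR description, not the exact code in atom/model_engine/sequence.py:

```python
from enum import Enum, auto

class SequenceStatus(Enum):
    WAITING = auto()
    WAITING_FOR_REMOTE_KVS = auto()  # blocks allocated, remote KV still in flight
    RUNNING = auto()
    FINISHED = auto()

class Sequence:
    def __init__(self, seq_id):
        self.seq_id = seq_id
        self.status = SequenceStatus.WAITING_FOR_REMOTE_KVS
        self.is_first_decode = False

def on_remote_kv_arrived(seq):
    """Promote a sequence once its remote KV blocks have landed.

    Blocks were pre-allocated during the async load phase, so no local
    re-allocation happens here; the sequence enters decode directly.
    """
    assert seq.status is SequenceStatus.WAITING_FOR_REMOTE_KVS
    seq.status = SequenceStatus.RUNNING
    seq.is_first_decode = True  # triggers the first-decode-after-remote-prefill path
    return seq
```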

Model Runner & Worker Pipeline

  • Dedicated KV output channel: AsyncIOProcManager adds a separate per-worker ZMQ channel for KV transfer status, preventing it from mixing with regular forward outputs.
  • Cross-TP aggregation: KVOutputAggregator collects finished_sending / finished_recving signals from all TP ranks before reporting completion to the scheduler.
  • Two new worker RPCs: process_kvconnector_output dispatches connector metadata to trigger RDMA loads; async_proc_aggregation collects transfer completion status. Both are called every engine step alongside the forward pass.
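The cross-TP aggregation rule — report a transfer finished only after every rank has reported it — can be sketched with a counter per request ID. This is a simplified stand-in for `KVOutputAggregator`, not the PR's exact implementation:

```python
from collections import Counter
from dataclasses import dataclass, field

@dataclass
class KVConnectorOutput:
    finished_sending: set = field(default_factory=set)
    finished_recving: set = field(default_factory=set)

class KVOutputAggregator:
    """Report a request finished only once all world_size ranks have reported it."""

    def __init__(self, world_size):
        self.world_size = world_size
        self._send_counts = Counter()  # request id -> ranks seen so far
        self._recv_counts = Counter()

    def aggregate(self, worker_outputs):
        done = KVConnectorOutput()
        for out in worker_outputs:
            for rid in out.finished_sending:
                self._send_counts[rid] += 1
                if self._send_counts[rid] == self.world_size:
                    del self._send_counts[rid]  # complete on all ranks
                    done.finished_sending.add(rid)
            for rid in out.finished_recving:
                self._recv_counts[rid] += 1
                if self._recv_counts[rid] == self.world_size:
                    del self._recv_counts[rid]
                    done.finished_recving.add(rid)
        return done
```

Counts persist across scheduler steps, so a rank that finishes a round later still completes the request.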

Each engine step now follows an extended pipeline:

  1. Schedule — select sequences, build batch with connector metadata
  2. Dispatch KV metadata — broadcast connector meta to all TP workers (triggers async RDMA load)
  3. Forward pass — run model on the scheduled batch
  4. Aggregate KV status — collect transfer completion from all workers
  5. Update scheduler — mark finished transfers, free deferred blocks
  6. Postprocess — emit tokens, handle stop conditions, notify connector on request completion
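The six steps above can be wired into a single step function. The stub object below only records call order; the method and RPC names mirror the PR description but are assumptions, not the real EngineCore API:

```python
class StubEngine:
    """Minimal stand-in so the step order is visible end to end."""
    def __init__(self):
        self.trace = []
    # scheduler side
    def schedule(self):
        self.trace.append("schedule")
        return {"connector_meta_output": {}}
    def update_kv_transfers(self, kv_done):
        self.trace.append("update_scheduler")
    def postprocess(self, outputs):
        self.trace.append("postprocess")
        return outputs
    # worker side
    def call_all(self, rpc, *args):
        self.trace.append(rpc)
        return []
    def call_func_with_aggregation(self, rpc):
        self.trace.append(rpc)
        return set()

def engine_step(engine):
    # 1. Schedule: select sequences, build batch with connector metadata.
    batch = engine.schedule()
    # 2. Dispatch connector metadata to all TP workers (triggers async RDMA load).
    engine.call_all("process_kvconnector_output", batch["connector_meta_output"])
    # 3. Forward pass on the scheduled batch.
    outputs = engine.call_all("forward")
    # 4. Aggregate KV transfer completion status across workers.
    kv_done = engine.call_func_with_aggregation("async_proc_aggregation")
    # 5. Mark finished transfers, free deferred blocks.
    engine.update_kv_transfers(kv_done)
    # 6. Postprocess: emit tokens, stop conditions, connector notification.
    return engine.postprocess(outputs)
```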

Frontend API Extensions

Proxy-based Request Orchestration

kv_transfer_engine

Implements the session lifecycle, the handshake protocol, block-granularity KV reads, async completion tracking, and the mori-io backend integration.

Other Related Modifications

Since seq_id is assigned based on processing progress, the prefill-side seq_id is used as the transfer_id, unifying block management and release across the prefill and decode instances.

AiterMLAMetadataBuilder decode metadata preparation is updated to handle the first-decode-after-remote-prefill case.


Thanks to @ZhangLirong-amd for the help.

Test Plan

Accuracy:

- [x] GSM8K 5-shot accuracy benchmark (1P1D, DeepSeek-V3, FP8, TP=8, concurrent=64)

| Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr   |
|-------|---------|------------------|--------|-------------|--------|----------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.9609 | ± 0.0121 |
|       |         | strict-match     | 5      | exact_match | 0.9570 | ± 0.0127 |

Performance:

- [x] 1P2D, using code base de5012b13a33596865f3c3c3dd1e31fadaaaccd8

Model: DeepSeek-V3 | 1P2D = 1 Prefill + 2 Decode nodes, all TP8 on MI300

TPOT ≤ 30 ms

| Config           | Single Node (tokens/s) | 1P2D Total (tokens/s) | 1P2D Per Node (tokens/s) |
|------------------|------------------------|-----------------------|--------------------------|
| ISL 1K / OSL 512 | 1155.3                 | 3533.4                | 1177.8                   |
| ISL 1K / OSL 1K  | 1246.3                 | 3883.8                | 1294.6                   |
| ISL 2K / OSL 1K  | 685.9                  | 2510.9                | 837.0                    |

TPOT ≤ 40 ms

| Config           | Single Node (tokens/s) | 1P2D Total (tokens/s) | 1P2D Per Node (tokens/s) |
|------------------|------------------------|-----------------------|--------------------------|
| ISL 1K / OSL 512 | 1646.3                 | 5385.7                | 1795.2                   |
| ISL 1K / OSL 1K  | 1823.2                 | 5829.3                | 1943.1                   |
| ISL 2K / OSL 1K  | 1667.8                 | 5113.4                | 1704.5                   |

inkcherry and others added 9 commits March 2, 2026 12:45
* clear schedule redundant variables

* async basic

* update

* update util

* update multi work rpc

* update

* add connector base

* add meta info

* update, finished return correct response

* request output stage1

* request output stage2 , finished

* finished ping proxy draft

* add proxy

* atom config to connector

* proxy pass

* update start loadkv

* handshake

* sess none

* fix and add sleep

* update ,transfer without crash, need to *16, need to do get_finished&scheduler

* add finished , but 3 step run will met none

* try to add decode only

* run but hang in single test(no toproxy)

* fix hang

* sperate process kv output

* pass test , need to do decode onlyt

* update, meet prepare input bu

* run success but accuracy false

* update , seems 1 seq random correct?

* update

* disable deffered out

* run gsmk8 con=64 without crash(ds5layer)

* update 0116 status to feiyue

* update

* fix proxy

* update

* update, prefill instance can do full task

* fix high con random acc issue

* remove debug

* update

* fix bench perf

* update

* update

* gsm 0.92

* update

* fix hang

* refactor: remove redundant comments and re-adjust code organization

* fix refactor

* use transfer id (ROCm#6)

* update test

* use transfer_id

* update

* use transfer_id to fix prefill decode mismatch

* Use transfer (ROCm#7)

* update test

* use transfer_id

* update

* use transfer_id to fix prefill decode mismatch

* update assert

* fix merge ,runwithout crash

* run without crash ,but meet acc issue

* update

* fix

* enable aiter log

---------

Co-authored-by: root <root@useocpm2m-097-088.amd.com>
Co-authored-by: root <root@useocpm2m-097-083.amd.com>
Co-authored-by: knitcapcat <zejwang@amd.com>
Copilot AI review requested due to automatic review settings March 2, 2026 11:38

Copilot AI left a comment


Pull request overview

Adds Prefill/Decode (P/D) KV-cache disaggregation support to ATOM, integrating a KV transfer connector (RDMA via MORI-IO), a routing proxy, and end-to-end plumbing through scheduling, runner execution, and the OpenAI-compatible server.

Changes:

  • Introduces atom/disaggregation/ (KV transfer engine, KV output aggregation, proxy service) plus documentation.
  • Wires KV-transfer state through engine components (scheduler/sequence/model_runner/async worker manager) and exposes transfer metadata via the OpenAI API.
  • Adds configuration/CLI support for kv_transfer_config and a small unit test for KV output aggregation.

Reviewed changes

Copilot reviewed 22 out of 22 changed files in this pull request and generated 14 comments.

| File | Description |
|------|-------------|
| test/disaggregation/test_kv_aggregator.py | Adds a basic test for cross-worker KV transfer completion aggregation. |
| atom/utils/network.py | Adds IP detection helper and env-var handling for host IP selection. |
| atom/utils/forward_context.py | Adds lazy global KV connector instances and registers KV caches into the connector. |
| atom/model_ops/attentions/aiter_mla.py | Adjusts decode metadata preparation for first-decode-after-remote-prefill behavior. |
| atom/model_engine/sequence.py | Adds new sequence status and KV-transfer-related fields. |
| atom/model_engine/scheduler.py | Integrates KV disaggregation into scheduling/postprocess and carries connector metadata in batches. |
| atom/model_engine/request.py | Extends streaming output object to include KV transfer output metadata. |
| atom/model_engine/model_runner.py | Adds worker-side dispatch of connector metadata and worker KV-finish reporting. |
| atom/model_engine/llm_engine.py | Parses KV transfer config and forwards per-request KV transfer params into sequences. |
| atom/model_engine/engine_core.py | Dispatches connector metadata to workers and aggregates KV completion across TP ranks each step. |
| atom/model_engine/block_manager.py | Introduces a constant intended for remote/prefill hashing coordination. |
| atom/model_engine/async_proc.py | Adds a dedicated KV output channel and aggregated RPC support across workers. |
| atom/model_engine/arg_utils.py | Adds CLI argument and default for --kv-transfer-config. |
| atom/entrypoints/openai_server.py | Accepts KV transfer params on requests and returns KV transfer metadata in responses/streaming chunks. |
| atom/disaggregation/requirements.txt | Adds disaggregation-specific Python dependencies list. |
| atom/disaggregation/proxy.py | Adds Quart-based proxy for service discovery + routing prefill/decode requests. |
| atom/disaggregation/kvoutput_aggregator.py | Adds aggregator to combine per-TP-worker finished send/recv signals. |
| atom/disaggregation/kv_transfer_engine.py | Adds MORI-IO based KV cache transfer connector (worker + scheduler components). |
| atom/disaggregation/__init__.py | Initializes the disaggregation package. |
| atom/disaggregation/README.md | Documents how to run proxy/prefill/decode for P/D disaggregation. |
| atom/config.py | Adds kv_transfer_config field and parsing for string-based configs. |
| README.md | Adds top-level note announcing P/D disaggregation support and doc link. |


)
stream_gen = stream_decode_response(session, decode_response, request_id)
response = await make_response(stream_gen)
response.headers["Content-Type"] = "application/json; charset=utf-8"

Copilot AI Mar 2, 2026


The proxy streams bytes from the decode instance (which are SSE-formatted in the OpenAI streaming API), but forces Content-Type: application/json; charset=utf-8. This can break streaming clients expecting text/event-stream. Consider preserving the upstream content-type or explicitly setting text/event-stream for streaming responses.

Suggested change
response.headers["Content-Type"] = "application/json; charset=utf-8"
response.headers["Content-Type"] = "text/event-stream; charset=utf-8"

Comment on lines +301 to +304

num_blocks_per_seq_bk = [
(ctx + self.block_size - 1) // self.block_size for ctx in batch.context_lens
]

Copilot AI Mar 2, 2026


num_blocks_per_seq_bk is computed but never used. Remove it (or use it) to avoid dead code and keep the decode metadata preparation easier to follow.

Suggested change
num_blocks_per_seq_bk = [
(ctx + self.block_size - 1) // self.block_size for ctx in batch.context_lens
]

Comment on lines +173 to 181
self.is_first_decode_without_local_prefill = [
seq.is_first_decode for seq in seqs.values()
]
self.temperatures = [seq.temperature for seq in seqs.values()]
self.context_lens = [seq.num_tokens for seq in seqs.values()]

# Build the flat scheduled-token array
offs = self.context_lens - self.num_rejected - self.num_scheduled_tokens
self.scheduled_tokens = np.empty(total_tokens_num, dtype=np.int32)

Copilot AI Mar 2, 2026


ScheduledBatch.__init__ sets self.context_lens to a Python list, but then computes offs = self.context_lens - self.num_rejected - self.num_scheduled_tokens, which will raise TypeError (list minus ndarray). Keep context_lens as a NumPy array (as before) or explicitly convert before doing vectorized arithmetic so batch construction works.

Comment on lines 423 to 426
if not getattr(seq, "is_first_decode", False):
self.block_manager.may_append(seq)

num_new_tokens = self.mtp_k + 1

Copilot AI Mar 2, 2026


Decode scheduling calls self.block_manager.may_append(seq) without passing num_new_tokens (it defaults to 1), even though each decode step schedules num_new_tokens = self.mtp_k + 1 tokens. This can under-allocate KV blocks when mtp_k > 0 (or near a block boundary) and lead to out-of-bounds writes later. Pass the correct num_new_tokens to may_append (and ensure the first-decode path still guarantees enough blocks).

Suggested change
if not getattr(seq, "is_first_decode", False):
self.block_manager.may_append(seq)
num_new_tokens = self.mtp_k + 1
num_new_tokens = self.mtp_k + 1
if not getattr(seq, "is_first_decode", False):
self.block_manager.may_append(seq, num_new_tokens=num_new_tokens)

@torch.inference_mode()
def async_proc_aggregation(self) -> KVConnectorOutput:
"""Collect finished send/recv status from the KV connector."""
done_sending, done_recving = get_kvconnector().get_finished()

Copilot AI Mar 2, 2026


async_proc_aggregation unconditionally does get_kvconnector().get_finished(). When kv_transfer_config is not enabled (or the connector failed to initialize), get_kvconnector() can return None, causing an AttributeError. Guard for None and return an empty KVConnectorOutput so EngineCore can call aggregation safely even in non-disaggregation runs.

Suggested change
done_sending, done_recving = get_kvconnector().get_finished()
connector = get_kvconnector()
if connector is None:
# No KV connector available (e.g., non-disaggregation run); return an "empty" result.
return KVConnectorOutput(finished_sending=False, finished_recving=False)
done_sending, done_recving = connector.get_finished()

)
self.input_thread.start()

self.kv_aggregator = KVOutputAggregator(world_size=config.tensor_parallel_size)

Copilot AI Mar 2, 2026


EngineCore.__init__ assigns self.kv_aggregator = KVOutputAggregator(...) but this field is never referenced (aggregation is done via runner_mgr.call_func_with_aggregation). Remove the unused instance variable or switch the code to use it to avoid dead state.

Comment on lines 137 to 147
@@ -103,6 +147,7 @@ def send_output_to_socket(self, addr: str, output_queue: queue.Queue):
socket.send(serialized_obj)

Copilot AI Mar 2, 2026


send_output_to_socket loops with while True: and blocks on output_queue.get() without checking still_running or consuming a sentinel, so AsyncIOProc.exit() cannot reliably stop/join these threads (especially now that an extra KV-output thread is created). Add a shutdown condition (e.g., check still_running and/or send a sentinel through the queue) so worker processes can exit cleanly.
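One conventional fix for the comment above is a sentinel pushed through the queue at shutdown. A sketch, not the PR's code — `send` stands in for the ZMQ socket send:

```python
import queue
import threading

_SENTINEL = object()  # unique shutdown marker, never a real output

def send_output_loop(output_queue, send):
    """Drain the queue until the sentinel arrives, then exit cleanly."""
    while True:
        item = output_queue.get()
        if item is _SENTINEL:
            break
        send(item)

def shutdown(output_queue, thread):
    """Called from exit(): unblock the sender thread and join it."""
    output_queue.put(_SENTINEL)
    thread.join()
```

The sentinel guarantees the blocked `get()` returns even when no more outputs arrive, so the worker process can join its sender threads and exit.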

Comment on lines +1 to +3
import os
import socket
import warnings

Copilot AI Mar 2, 2026


New module atom/utils/network.py is missing the SPDX license header used throughout atom/utils/* (e.g., atom/utils/envs.py). Add the standard SPDX + copyright header for consistency and licensing clarity.

Comment on lines +388 to +414
if not (hasattr(config, "kv_transfer_config") and config.kv_transfer_config):
return _global_kvconnector

if role == "worker":
from aiter.dist.parallel_state import get_tp_group

try:
tp_rank = get_tp_group().rank_in_group
except Exception:
_logger.warning(
"get_tp_group() failed (dist not initialized?), returning None"
)
return None

if _global_kvconnector is None:
from atom.disaggregation.kv_transfer_engine import KVConnector

_global_kvconnector = KVConnector(config)
_logger.debug("Initialized global KVConnector at tp_rank %d", tp_rank)

elif role == "scheduler":
from atom.disaggregation.kv_transfer_engine import KVConnectorScheduler

_global_kvconnector_scheduler = KVConnectorScheduler(config)
_logger.debug("Initialized global KVConnectorScheduler")
return _global_kvconnector_scheduler


Copilot AI Mar 2, 2026


get_kvconnector returns _global_kvconnector when KV transfer is not configured, even for role="scheduler". This makes the return value role-inconsistent and can leak a stale worker connector; callers also can’t distinguish “disabled” from “initialized earlier”. Return None when not configured, and return the role-appropriate global (_global_kvconnector_scheduler vs _global_kvconnector).

)

# --- Decode request ---
req_data["max_tokens"] -= 1

Copilot AI Mar 2, 2026


Proxy assumes req_data["max_tokens"] exists and is an int (req_data["max_tokens"] -= 1). Since this proxy accepts raw JSON (no schema validation), requests without max_tokens (or with None) will raise. Use a safe default / validation and handle max_completion_tokens as well if you intend to support both OpenAI-style fields.

Suggested change
req_data["max_tokens"] -= 1
max_tokens = req_data.get("max_tokens")
max_completion_tokens = req_data.get("max_completion_tokens")
effective_max_tokens = None
if isinstance(max_tokens, int):
effective_max_tokens = max_tokens
elif isinstance(max_completion_tokens, int):
effective_max_tokens = max_completion_tokens
if effective_max_tokens is not None:
# Reserve one token for prefill and ensure non-negative.
decode_max_tokens = max(0, effective_max_tokens - 1)
req_data["max_tokens"] = decode_max_tokens
if isinstance(max_completion_tokens, int) and not isinstance(max_tokens, int):
# Keep max_completion_tokens in sync if it was the source.
req_data["max_completion_tokens"] = decode_max_tokens

Copilot AI review requested due to automatic review settings March 22, 2026 08:45

Copilot AI left a comment


Pull request overview

Copilot reviewed 22 out of 23 changed files in this pull request and generated 22 comments.



while True:
try:
identity, msg = sock.recv_multipart()
self._handle_message(msg)

Copilot AI Mar 22, 2026


async_wait_reqid() calls self._handle_message(msg), but there is no _handle_message method defined on MoRIIOWrapper (only _dispatch_message). If this code path is exercised it will raise AttributeError and break notifications. Rename the call to the correct handler or implement _handle_message.

Suggested change
self._handle_message(msg)
self._dispatch_message(msg)

Comment on lines +439 to +451
def write_remote_data_single(
self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0
):
assert self.local_memory_registered, "You have not register local memory data!"
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
transfer_status = self.sessions[sess_idx].write(
local_offset,
remote_offset,
transfer_size_byte,
self.moriio_engine.allocate_transfer_uid(),
)
with self.lock:
self.transfer_status.append(transfer_status)

Copilot AI Mar 22, 2026


write_remote_data_single() references self.sessions[sess_idx], but self.sessions is never initialized in MoRIIOWrapper, so this will raise AttributeError if called. Either initialize self.sessions (and document its lifecycle) or remove/replace this method if it’s not used.

"""Collect finished send/recv status from the KV connector."""
connector = get_kvconnector()
if connector is None:
return KVConnectorOutput(finished_sending=[], finished_recving=[])

Copilot AI Mar 22, 2026


When connector is None, async_proc_aggregation() returns KVConnectorOutput(finished_sending=[], finished_recving=[]). The KVConnectorOutput dataclass is defined with set[...] fields, so returning lists here is inconsistent and can lead to surprises if callers assume set semantics. Prefer returning empty sets (or relying on defaults) for both fields.

Suggested change
return KVConnectorOutput(finished_sending=[], finished_recving=[])
return KVConnectorOutput(finished_sending=set(), finished_recving=set())

Comment on lines +16 to +46
# Phase 2: All workers have finished sending req_id="1"
worker_outputs = [
KVConnectorOutput(finished_sending={"1"}) for _ in range(8)
]
finished = aggregator.aggregate(worker_outputs)
print(f"Round 2 finished_sending: {finished.finished_sending}, finished_recving: {finished.finished_recving}")
assert finished.finished_sending == {"1"}
assert finished.finished_recving == set()

# Phase 3: All workers have finished receiving req_id="1"
worker_outputs = [
KVConnectorOutput(finished_recving={"1"}) for _ in range(8)
]
finished = aggregator.aggregate(worker_outputs)
print(f"Round 3 finished_sending: {finished.finished_sending}, finished_recving: {finished.finished_recving}")
assert finished.finished_sending == set()
assert finished.finished_recving == {"1"}

print("\nTest 2: Complex scenario")
aggregator2 = KVOutputAggregator(world_size=8)

# First round: Some workers complete some requests
worker_outputs = [
KVConnectorOutput(finished_sending={"1", "2"}, finished_recving={"1"}),
KVConnectorOutput(finished_sending={"1"}),
KVConnectorOutput(finished_sending={"2"}, finished_recving={"2"}),
KVConnectorOutput(finished_sending={"1"}),
KVConnectorOutput(), # Empty output
KVConnectorOutput(finished_sending={"1"}, finished_recving={"3"}),
KVConnectorOutput(finished_sending={"2"}),
KVConnectorOutput(finished_sending={"1"})

Copilot AI Mar 22, 2026


The test uses string request IDs (e.g. "1") for finished_sending/finished_recving, but the engine appears to use integer seq.id/transfer_id values for KV completion tracking. Using the same ID type in tests as production (likely int) would better catch type-mismatch bugs where completions never match scheduler sequence IDs.

Suggested change
# Phase 2: All workers have finished sending req_id="1"
worker_outputs = [
KVConnectorOutput(finished_sending={"1"}) for _ in range(8)
]
finished = aggregator.aggregate(worker_outputs)
print(f"Round 2 finished_sending: {finished.finished_sending}, finished_recving: {finished.finished_recving}")
assert finished.finished_sending == {"1"}
assert finished.finished_recving == set()
# Phase 3: All workers have finished receiving req_id="1"
worker_outputs = [
KVConnectorOutput(finished_recving={"1"}) for _ in range(8)
]
finished = aggregator.aggregate(worker_outputs)
print(f"Round 3 finished_sending: {finished.finished_sending}, finished_recving: {finished.finished_recving}")
assert finished.finished_sending == set()
assert finished.finished_recving == {"1"}
print("\nTest 2: Complex scenario")
aggregator2 = KVOutputAggregator(world_size=8)
# First round: Some workers complete some requests
worker_outputs = [
KVConnectorOutput(finished_sending={"1", "2"}, finished_recving={"1"}),
KVConnectorOutput(finished_sending={"1"}),
KVConnectorOutput(finished_sending={"2"}, finished_recving={"2"}),
KVConnectorOutput(finished_sending={"1"}),
KVConnectorOutput(), # Empty output
KVConnectorOutput(finished_sending={"1"}, finished_recving={"3"}),
KVConnectorOutput(finished_sending={"2"}),
KVConnectorOutput(finished_sending={"1"})
# Phase 2: All workers have finished sending req_id=1
worker_outputs = [
KVConnectorOutput(finished_sending={1}) for _ in range(8)
]
finished = aggregator.aggregate(worker_outputs)
print(f"Round 2 finished_sending: {finished.finished_sending}, finished_recving: {finished.finished_recving}")
assert finished.finished_sending == {1}
assert finished.finished_recving == set()
# Phase 3: All workers have finished receiving req_id=1
worker_outputs = [
KVConnectorOutput(finished_recving={1}) for _ in range(8)
]
finished = aggregator.aggregate(worker_outputs)
print(f"Round 3 finished_sending: {finished.finished_sending}, finished_recving: {finished.finished_recving}")
assert finished.finished_sending == set()
assert finished.finished_recving == {1}
print("\nTest 2: Complex scenario")
aggregator2 = KVOutputAggregator(world_size=8)
# First round: Some workers complete some requests
worker_outputs = [
KVConnectorOutput(finished_sending={1, 2}, finished_recving={1}),
KVConnectorOutput(finished_sending={1}),
KVConnectorOutput(finished_sending={2}, finished_recving={2}),
KVConnectorOutput(finished_sending={1}),
KVConnectorOutput(), # Empty output
KVConnectorOutput(finished_sending={1}, finished_recving={3}),
KVConnectorOutput(finished_sending={2}),
KVConnectorOutput(finished_sending={1})

):
seq = self.waiting[0]
# --- Prefill scheduling ---
while self.waiting and num_seqs_prefill < self.max_num_seqs:

Copilot AI Mar 22, 2026


The prefill scheduling loop no longer applies the delay_factor batching logic (_passed_delay()), which changes scheduling behavior and can reduce batching efficiency under load. If this is intentional for disaggregation, consider gating it only for the remote-KV path; otherwise reintroduce the delay condition to preserve existing throughput/latency tradeoffs.

Suggested change
while self.waiting and num_seqs_prefill < self.max_num_seqs:
while self.waiting and num_seqs_prefill < self.max_num_seqs:
# For non-disaggregated setups, respect delay_factor-based batching
# to preserve existing throughput/latency tradeoffs. The delay is
# disabled when using a KV connector, since disaggregation already
# governs prefill timing via remote KV availability.
if self.kv_connector is None and not self._passed_delay():
break

Comment on lines +29 to +44
@dataclass
class KVConnectorOutput:
"""Per-worker snapshot of finished KV cache transfers.

Each TP worker produces one of these per scheduler step. The
:class:`KVOutputAggregator` combines them to determine which
request IDs have finished on *all* workers.

Attributes:
finished_sending: Request IDs whose KV send completed on this worker.
finished_recving: Request IDs whose KV receive completed on this worker.
"""

finished_sending: set[str] = field(default_factory=set)
finished_recving: set[str] = field(default_factory=set)


Copilot AI Mar 22, 2026


KVConnectorOutput uses set[str] for request IDs, but the rest of the engine uses integer sequence/request IDs. Standardizing this type (e.g., set[int] or a ReqId = int alias) would prevent hard-to-debug mismatches where finished IDs never match scheduler sequence IDs.

Comment on lines 621 to 626
world_size: int = field(init=False)
"""world_size is TPxPP, it affects the number of workers we create."""
data_parallel_master_port: int = 29500
data_parallel_master_port: int = field(default_factory=get_open_port)
"""Port of the data parallel master."""

data_parallel_base_port: int = get_open_port()

Copilot AI Mar 22, 2026


data_parallel_master_port now defaults to get_open_port(). If each DP rank constructs its own Config independently (common in multi-process launchers), they may pick different master ports and fail to rendezvous. The previous fixed default (29500) avoided this class of mismatch. If the goal is to avoid port collisions, consider deriving a single deterministic port (e.g., from an env var like ATOM_DP_MASTER_PORT, or selecting once on rank 0 and broadcasting) rather than calling get_open_port() during dataclass construction on every rank.
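The env-var alternative suggested above can be sketched as follows; `ATOM_DP_MASTER_PORT` is a hypothetical variable name, not one the codebase is known to define:

```python
import os

def get_dp_master_port(default: int = 29500) -> int:
    """Resolve one deterministic master port for all DP ranks.

    Every rank reads the same env override (or the same fixed default),
    so independently constructed Configs still agree on the rendezvous
    port -- unlike a per-process get_open_port() call.
    """
    return int(os.environ.get("ATOM_DP_MASTER_PORT", default))
```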

Comment on lines +1 to +3
msgspec
msgpack
quart No newline at end of file

Copilot AI Mar 22, 2026


These dependencies are required by new runtime code (kv_transfer_engine.py imports msgspec/msgpack, proxy.py imports quart), but they are not included in the package install deps (pyproject.toml [project].dependencies). As-is, enabling disaggregation will likely fail with ModuleNotFoundError unless users install this extra file manually. Consider adding them to pyproject.toml (possibly as an optional extra) and referencing that in docs.

README.md Outdated

## 📢 News

- **[2026/03]** ATOM now supports **Prefill/Decode (P/D) disaggregation** — run prefill and decode on separate GPU nodes with RDMA-based KV cache transfer via [MORI-IO](https://github.com/ROCm/mori). See [disaggregation docs](atom/disaggregation/README.md).

Copilot AI Mar 22, 2026


The link path appears incorrect: this PR adds disaggregation docs under atom/mesh/disaggregation/README.md, but the README links to atom/disaggregation/README.md. Update the link to the actual location so it doesn't 404.

Suggested change
- **[2026/03]** ATOM now supports **Prefill/Decode (P/D) disaggregation** — run prefill and decode on separate GPU nodes with RDMA-based KV cache transfer via [MORI-IO](https://github.com/ROCm/mori). See [disaggregation docs](atom/disaggregation/README.md).
- **[2026/03]** ATOM now supports **Prefill/Decode (P/D) disaggregation** — run prefill and decode on separate GPU nodes with RDMA-based KV cache transfer via [MORI-IO](https://github.com/ROCm/mori). See [disaggregation docs](atom/mesh/disaggregation/README.md).

Comment on lines +373 to +387

num_blocks_per_seq_bk = [
(ctx + self.block_size - 1) // self.block_size for ctx in batch.context_lens
]
num_blocks_per_seq = []
for i, (ctx, is_first) in enumerate(zip(batch.context_lens, batch.is_first_decode_without_local_prefill)):
if is_first:
# First decode after remote prefill: use pre-allocated block count
blocks = len(batch.block_tables[i])
else:
# Normal case: ceil-divide context length by block size
blocks = (ctx + self.block_size - 1) // self.block_size
num_blocks_per_seq.append(blocks)
sum_blocks_before_converted = sum([(i + self.block_ratio - 1) // self.block_ratio for i in num_blocks_per_seq])


Copilot AI Mar 22, 2026


num_blocks_per_seq is computed from batch.context_lens, but earlier in this method context_lens is adjusted (e.g., subtracting num_rejected for speculative decode) and block_tables may be truncated accordingly. Using the unadjusted batch.context_lens can make kv_indptr inconsistent with the actual block tables used for slot mapping, potentially breaking decode with speculative rejection. Prefer deriving block counts from the adjusted context_lens/block_tables used above.

Suggested change:

```diff
-num_blocks_per_seq_bk = [
-    (ctx + self.block_size - 1) // self.block_size for ctx in batch.context_lens
-]
-num_blocks_per_seq = []
-for i, (ctx, is_first) in enumerate(zip(batch.context_lens, batch.is_first_decode_without_local_prefill)):
-    if is_first:
-        # First decode after remote prefill: use pre-allocated block count
-        blocks = len(batch.block_tables[i])
-    else:
-        # Normal case: ceil-divide context length by block size
-        blocks = (ctx + self.block_size - 1) // self.block_size
-    num_blocks_per_seq.append(blocks)
-sum_blocks_before_converted = sum([(i + self.block_ratio - 1) // self.block_ratio for i in num_blocks_per_seq])
+# Derive the number of blocks per sequence from the prepared block tables
+# to ensure consistency with any adjustments (e.g., speculative rejection).
+num_blocks_per_seq = [len(bt) for bt in batch.block_tables]
+sum_blocks_before_converted = sum(
+    [(i + self.block_ratio - 1) // self.block_ratio for i in num_blocks_per_seq]
+)
```

Copilot AI review requested due to automatic review settings March 22, 2026 11:06
Copilot AI (Contributor) left a comment:

Pull request overview

Copilot reviewed 30 out of 31 changed files in this pull request and generated 9 comments.



Comment on lines +260 to +270
```python
if not prefill_instances or not decode_instances:
    return await make_response((
        "Service Unavailable: no prefill or decode instances registered.",
        503,
    ))

# Round-robin instance selection
pid = request_nums % len(prefill_instances)
did = request_nums % len(decode_instances)
prefill_ep = prefill_instances[pid]
decode_ep = decode_instances[did]
```
Copilot AI (Mar 22, 2026):
handle_request() reads prefill_instances / decode_instances without _list_lock, but those lists are mutated concurrently by the registration listener thread. This can lead to inconsistent len() / indexing (e.g., IndexError), especially around pid/did selection. Wrap the availability check + endpoint selection in with _list_lock: (or copy the lists under lock before using them).

Suggested change:

```diff
-if not prefill_instances or not decode_instances:
-    return await make_response((
-        "Service Unavailable: no prefill or decode instances registered.",
-        503,
-    ))
-
-# Round-robin instance selection
-pid = request_nums % len(prefill_instances)
-did = request_nums % len(decode_instances)
-prefill_ep = prefill_instances[pid]
-decode_ep = decode_instances[did]
+# Protect access to instance lists with the shared lock to avoid
+# races with the registration listener thread.
+with _list_lock:
+    if not prefill_instances or not decode_instances:
+        return await make_response((
+            "Service Unavailable: no prefill or decode instances registered.",
+            503,
+        ))
+    # Round-robin instance selection
+    pid = request_nums % len(prefill_instances)
+    did = request_nums % len(decode_instances)
+    prefill_ep = prefill_instances[pid]
+    decode_ep = decode_instances[did]
```

Comment on lines +137 to +145
```bash
python -m atom.entrypoints.openai_server \
    --kv_cache_dtype fp8 \
    --model /path/to/model \
    --block-size 16 \
    -tp 8 \
    --enable-dp-attention \
    --enable-expert-parallel \
    --kv-transfer-config '{"kv_role":"kv_consumer","proxy_ip":"<PROXY_IP>","proxy_ping_port":36367,"http_prt":8000}'
```
Copilot AI (Mar 22, 2026):
This command example uses "http_prt" in --kv-transfer-config, but the code reads http_port. With the current spelling, the HTTP port override won’t take effect. Replace http_prt with http_port here.

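Typos like `http_prt` are silently ignored by a parser that only reads known keys. A minimal sketch of a startup key check that would surface them; `EXPECTED_KEYS` and the helper name are illustrative, not part of the PR:

```python
import json

# Hypothetical accepted key set; the real one lives in the PR's KV-transfer
# config parsing code.
EXPECTED_KEYS = {"kv_role", "proxy_ip", "proxy_ping_port", "http_port"}

def unknown_kv_config_keys(raw: str) -> set[str]:
    """Return keys the parser would silently ignore (e.g. the 'http_prt' typo)."""
    return set(json.loads(raw)) - EXPECTED_KEYS
```

Logging or rejecting a non-empty result at startup turns a silent misconfiguration into an immediate, actionable error.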
Comment on lines +1 to +6
```python
import os
import socket
import warnings

def get_ip() -> str:
    # Check environment variable first
```
Copilot AI (Mar 22, 2026):
atom/utils/network.py is missing the SPDX license/copyright header that appears to be standard across the codebase (e.g., atom/utils/envs.py:1-2, atom/utils/forward_context.py:1-3). Add the SPDX header here for consistency and automated license scanning.

Comment on lines +399 to +401
```python
if not (hasattr(config, "kv_transfer_config") and config.kv_transfer_config):
    return _global_kvconnector
```
Copilot AI (Mar 22, 2026):
get_kvconnector() returns _global_kvconnector whenever config.kv_transfer_config is missing/empty, regardless of the requested role. This can return the worker connector when the caller asked for the scheduler connector (or vice versa) if one was previously initialized. Consider returning _global_kvconnector_scheduler when role == "scheduler" (and _global_kvconnector for "worker"), even in the early-return path.

```diff
     ],
     record_shapes=enable_detailed_profiling,
-    with_stack=enable_detailed_profiling,
+    with_stack=True,
```
Copilot AI (Mar 22, 2026):
start_profiler()’s docstring says with_stack is controlled by ATOM_PROFILER_MORE, but the profiler is created with with_stack=True unconditionally. Capturing stacks is expensive and can significantly increase profiling overhead and trace size; set with_stack=enable_detailed_profiling (or update the docstring if always-on is intentional).

Suggested change:

```diff
-with_stack=True,
+with_stack=enable_detailed_profiling,
```

Comment on lines +122 to +130
```bash
python -m atom.entrypoints.openai_server \
    --kv_cache_dtype fp8 \
    --model /path/to/model \
    --block-size 16 \
    -tp 8 \
    --enable-dp-attention \
    --enable-expert-parallel \
    --kv-transfer-config '{"kv_role":"kv_producer","proxy_ip":"<PROXY_IP>","proxy_ping_port":36367,"http_prt":8000}'
```
Copilot AI (Mar 22, 2026):
This command example uses "http_prt" in --kv-transfer-config, but the code reads http_port. With the current spelling, the HTTP port override won’t take effect. Replace http_prt with http_port here.

Comment on lines +716 to +731
```python
# If a handshake was needed, spin until it completes then read.
while need_handshake:
    if (
        self._ready_requests.empty()
        and remote_engine_id not in self.load_ready_flag
    ):
        continue
    elif (
        not self._ready_requests.empty()
        and remote_engine_id in self.load_ready_flag
    ):
        self._issue_read_for_req(*self._ready_requests.get_nowait())
        break
    else:
        break
```

Copilot AI (Mar 22, 2026):
start_load_kv() busy-spins in while need_handshake: with a tight continue loop while waiting for handshake completion. This can peg a CPU core during handshake latency. Prefer a blocking wait (e.g., Queue.get() with timeout) or at least add a short sleep/backoff and a clear exit condition; also consider tracking per-engine futures/events instead of polling shared dicts.

Suggested change:

```diff
-# If a handshake was needed, spin until it completes then read.
-while need_handshake:
-    if (
-        self._ready_requests.empty()
-        and remote_engine_id not in self.load_ready_flag
-    ):
-        continue
-    elif (
-        not self._ready_requests.empty()
-        and remote_engine_id in self.load_ready_flag
-    ):
-        self._issue_read_for_req(*self._ready_requests.get_nowait())
-        break
-    else:
-        break
+# If a handshake was needed, wait until it completes then read.
+if need_handshake:
+    handshake_timeout = 5.0  # seconds
+    start_time = time.time()
+    while True:
+        # Exit if handshake takes too long to avoid indefinite spinning.
+        if time.time() - start_time > handshake_timeout:
+            logger.warning(
+                "Timed out waiting for KV handshake with remote engine %s "
+                "after %.1f seconds",
+                remote_engine_id,
+                handshake_timeout,
+            )
+            break
+        if (
+            self._ready_requests.empty()
+            and remote_engine_id not in self.load_ready_flag
+        ):
+            # Handshake not yet complete; back off briefly to avoid busy-spin.
+            time.sleep(0.001)
+            continue
+        elif (
+            not self._ready_requests.empty()
+            and remote_engine_id in self.load_ready_flag
+        ):
+            self._issue_read_for_req(*self._ready_requests.get_nowait())
+            break
+        else:
+            # Either handshake completed but no matching ready request, or
+            # state changed unexpectedly; exit loop and let caller decide.
+            break
```

Comment on lines 66 to 68
```python
logger.info(
    f"LLMEngine init with {self.data_parallel_size} data parallel ranks"
)
```
Copilot AI (Mar 22, 2026):
There are two identical logger.info calls for the data-parallel rank count, which will duplicate logs on every engine init. Remove one of them to keep logs clean.

Suggested change:

```diff
-logger.info(
-    f"LLMEngine init with {self.data_parallel_size} data parallel ranks"
-)
```

Comment on lines +438 to +447
```python
        # Skip block append for the first decode step after remote
        # prefill — blocks were already allocated during prefill.
        if not getattr(seq, "is_first_decode", False):
            self.block_manager.may_append(seq)

        num_new_tokens = self.mtp_k + 1
        self.block_manager.may_append(seq, num_new_tokens)
        scheduled_seqs[seq.id] = seq
        seq.type = SequenceType.DECODE
        num_scheduled_tokens.append(num_new_tokens)
        seq.is_first_decode = False

    num_scheduled_tokens_np = num_scheduled_tokens
    total_tokens_num_decode = sum(num_scheduled_tokens_np)

    assert scheduled_seqs
    self.running.extendleft(reversed(scheduled_seqs.values()))
    # logger.info(
    #     f"Scheduled decode batch: {num_seqs_decode} reqs, {total_tokens_num_decode} tokens, req_ids: {tuple(scheduled_seqs.keys())}"
    # )
    return (
        ScheduledBatch(
            seqs=scheduled_seqs,
            num_scheduled_tokens=num_scheduled_tokens_np,
            total_tokens_num=total_tokens_num_decode,
            total_tokens_num_decode=total_tokens_num_decode,
            total_seqs_num=num_seqs_prefill + num_seqs_decode,
            total_seqs_num_prefill=num_seqs_prefill,
            total_seqs_num_decode=num_seqs_decode,
            num_spec_step=self.mtp_k,
            scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
        ),
        scheduled_seqs,
    )
```
Copilot AI (Mar 22, 2026):
seq.is_first_decode is being reset to False during decode scheduling before ScheduledBatch is constructed. Since ScheduledBatch.__init__ reads seq.is_first_decode to populate is_first_decode_without_local_prefill, the batch will never see True, breaking the “first-decode-after-remote-prefill” path (e.g., MLA decode metadata will compute block counts as if local prefill happened). Reset the flag only after the batch snapshot is created (or snapshot the flag into a local variable / batch field before mutating the sequence).

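The fix the comment asks for is to snapshot the flag into the batch before clearing it on the sequence. A minimal runnable sketch of that ordering; the `Seq`/`Batch` classes and `schedule_decode` helper are simplified stand-ins for the PR's types:

```python
from dataclasses import dataclass, field

@dataclass
class Seq:
    id: int
    is_first_decode: bool = False

@dataclass
class Batch:
    is_first_decode_without_local_prefill: list = field(default_factory=list)

def schedule_decode(seqs):
    batch = Batch()
    for seq in seqs:
        # Snapshot the flag into the batch BEFORE clearing it on the sequence,
        # so the batch still observes True for first-decode-after-remote-prefill.
        batch.is_first_decode_without_local_prefill.append(seq.is_first_decode)
        seq.is_first_decode = False  # safe: batch already captured the value
    return batch
```

Resetting inside the loop after the snapshot keeps the per-step semantics (the flag only applies to the first decode) without hiding it from the batch.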
Copilot AI review requested due to automatic review settings March 23, 2026 03:52
Copilot AI (Contributor) left a comment:

Pull request overview

Copilot reviewed 30 out of 31 changed files in this pull request and generated 15 comments.

Comments suppressed due to low confidence (2)

atom/model_engine/model_runner.py:15

  • torch.distributed as dist is imported but not used anywhere in this file. If linting is enabled this will fail CI; otherwise it’s dead code and should be removed.

```python
import torch
import torch.profiler as torch_profiler
import tqdm
```
atom/model_engine/llm_engine.py:201

  • Sequence(...) is now constructed without passing num_draft_tokens=self.num_speculative_tokens and mamba_enabled=self.mamba_enabled. This will silently disable speculative decoding and mamba handling for all requests when KV disaggregation is enabled (and even when it isn’t, since the defaults are 0/False). Please pass the existing parameters through while adding kv_transfer_params.

```python
        if self.config.hf_config.model_type == "qwen3_next":
            self.mamba_enabled = True

    def preprocess(
        self,
        prompt_or_tokens: str | list[int],
        sampling_params: SamplingParams,
        stream_callback=None,
        kv_transfer_params=None,
    ):
        """responsible for:
        1) Tokenize
        2) Create Sequence object"""
        tokens = (
            self.tokenizer.encode(prompt_or_tokens)
            if isinstance(prompt_or_tokens, str)
            else prompt_or_tokens
        )

        stop_token_sequences = []
        if sampling_params.stop_strings:
            stops = (
                [sampling_params.stop_strings]
                if isinstance(sampling_params.stop_strings, str)
                else sampling_params.stop_strings
            )
            for stop_str in stops:
```

Comment on lines +422 to +427
```python
from atom.mesh.disaggregation import KVConnectorFactory

_global_kvconnector_scheduler = KVConnectorFactory.create_connector(
    config, role="scheduler"
)
_logger.debug("Initialized global KVConnectorScheduler")
```
Copilot AI (Mar 23, 2026):
In the role == "scheduler" branch, _global_kvconnector_scheduler is re-created on every call (no if _global_kvconnector_scheduler is None guard). That can leak resources / duplicate background threads/sockets if get_kvconnector("scheduler", ...) is called multiple times; make it truly lazy like the worker connector.

Suggested change:

```diff
-from atom.mesh.disaggregation import KVConnectorFactory
-
-_global_kvconnector_scheduler = KVConnectorFactory.create_connector(
-    config, role="scheduler"
-)
-_logger.debug("Initialized global KVConnectorScheduler")
+if _global_kvconnector_scheduler is None:
+    from atom.mesh.disaggregation import KVConnectorFactory
+
+    _global_kvconnector_scheduler = KVConnectorFactory.create_connector(
+        config, role="scheduler"
+    )
+    _logger.debug("Initialized global KVConnectorScheduler")
```

Comment on lines 444 to 449
```python
# Skip block append for the first decode step after remote
# prefill — blocks were already allocated during prefill.
if not getattr(seq, "is_first_decode", False):
    self.block_manager.may_append(seq)

num_new_tokens = self.mtp_k + 1
```
Copilot AI (Mar 23, 2026):
Decode scheduling computes num_new_tokens = self.mtp_k + 1 but calls self.block_manager.may_append(seq) without passing num_new_tokens. For speculative decode (mtp_k>0) this can under-allocate KV blocks, since BlockManager.may_append uses num_new_tokens to decide how many blocks are needed.

Suggested change:

```diff
-# Skip block append for the first decode step after remote
-# prefill — blocks were already allocated during prefill.
-if not getattr(seq, "is_first_decode", False):
-    self.block_manager.may_append(seq)
-
-num_new_tokens = self.mtp_k + 1
+num_new_tokens = self.mtp_k + 1
+# Skip block append for the first decode step after remote
+# prefill — blocks were already allocated during prefill.
+if not getattr(seq, "is_first_decode", False):
+    self.block_manager.may_append(seq, num_new_tokens=num_new_tokens)
```

Comment on lines +340 to +343
```python
# --- Prefill scheduling ---
while self.waiting and num_seqs_prefill < self.max_num_seqs:
    seq = self.waiting.popleft()
```

Copilot AI (Mar 23, 2026):
The prefill loop no longer uses the delay_factor / _passed_delay() gating (but _passed_delay is still implemented). This is a behavior change that can regress batching efficiency and leaves dead code behind; either reintroduce the delay condition or remove the unused delay fields/method to keep scheduling logic consistent.

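For reference, one common shape of such delay gating (similar to vLLM's scheduler `delay_factor`): hold new prefills until `delay_factor` times the recent prefill latency has elapsed, so more waiting requests accumulate into one batch. The class below is a simplified, hypothetical sketch, not the PR's `_passed_delay` implementation:

```python
class PrefillDelayGate:
    """Gate prefill batches: wait delay_factor x recent prefill latency
    between prefill rounds so waiting requests can batch together."""

    def __init__(self, delay_factor: float = 0.5):
        self.delay_factor = delay_factor
        self.prev_prompt_latency = 0.0  # EMA of recent prefill latency (s)
        self.last_prompt_time = 0.0     # monotonic time of last prefill

    def passed_delay(self, now: float) -> bool:
        elapsed = now - self.last_prompt_time
        return elapsed >= self.delay_factor * self.prev_prompt_latency
```

The loop condition would then become `while self.waiting and num_seqs_prefill < self.max_num_seqs and self._passed_delay(now):` if the gating is reintroduced; otherwise the unused fields and method should be deleted.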
Comment on lines +700 to +712
```python
remote_engine_id = f"{meta.remote_host}:{meta.remote_handshake_port}"
meta.remote_engine_id = remote_engine_id
dp0_id = self._engine_name_with_dp(remote_engine_id, 0)

if dp0_id not in self._remote_agents:
    with self._handshake_lock:
        if remote_engine_id not in self._remote_agents:
            self._initiate_background_handshake(
                req_id, remote_engine_id, meta
            )
    need_handshake = True
    continue
```
Copilot AI (Mar 23, 2026):
Handshake cache lookup uses dp0_id = engine_dp0 but inside the lock checks if remote_engine_id not in self._remote_agents:. Since _remote_agents is populated with DP-suffixed keys (e.g. ..._dp0), this condition will stay true and can trigger redundant handshakes for every request. The inner check should likely use dp0_id (or otherwise normalize keys consistently).

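The classic double-checked pattern only works when both checks use the same key. A minimal runnable sketch of the consistent-key version; the module-level dict, lock, and `ensure_remote_agent` helper are simplified stand-ins for the connector's state:

```python
import threading

_remote_agents: dict = {}
_handshake_lock = threading.Lock()
handshakes_started: list = []

def ensure_remote_agent(dp0_id: str) -> None:
    """Double-checked handshake init, keyed by the SAME DP-suffixed id on
    both checks."""
    if dp0_id not in _remote_agents:
        with _handshake_lock:
            # Re-checking a different key (the un-suffixed engine id) would
            # always miss and kick off a redundant handshake per request.
            if dp0_id not in _remote_agents:
                handshakes_started.append(dp0_id)
                _remote_agents[dp0_id] = object()  # placeholder agent handle
```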
Comment on lines +21 to +47
```python
EngineId = str
ReqId = str
TransferId = int

# ---------------------------------------------------------------------------
# Dataclasses
# ---------------------------------------------------------------------------


@dataclass
class KVConnectorOutput:
    """Per-worker snapshot of finished KV cache transfers.

    Each TP worker produces one of these per scheduler step. The
    :class:`KVOutputAggregator` combines them to determine which
    request IDs have finished on *all* workers.

    Attributes:
        finished_sending: Request IDs whose KV send completed on this worker.
        finished_recving: Request IDs whose KV receive completed on this worker.
        expected_finished_count: How many finished notifications should be
            expected per request (used by the aggregator).
    """

    finished_sending: set[str] = field(default_factory=set)
    finished_recving: set[str] = field(default_factory=set)
    expected_finished_count: int = 0
```
Copilot AI (Mar 23, 2026):
Type aliases and KVConnectorOutput currently use ReqId = str / set[str], but the scheduler/engine paths (e.g., Sequence.id, deferred_free_blocks, and POP_DONE_RECV parsing) operate on integer request IDs. This mismatch will confuse API consumers and can break equality checks if any path starts emitting strings. Consider standardizing on int request IDs here (or consistently converting to str at boundaries).

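If the ids are not standardized, normalizing at the boundary avoids the silent mismatch: `123 in {"123"}` is False, so a finished transfer keyed by string would never match an int sequence id. A tiny illustrative helper (not from the PR):

```python
def is_transfer_finished(finished_ids: set, req_id) -> bool:
    """Normalize at the boundary so int scheduler ids match str connector ids."""
    # Without the str() conversion, an int id never equals its string form
    # and the finished transfer is never observed by the scheduler.
    return str(req_id) in finished_ids
```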
Comment on lines 324 to 339
```python
def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]:
    # prefill
    """Select the next batch of sequences for a forward pass.

    Tries prefill first; if no new prefills are ready, falls back to
    decoding already-running sequences.
    """
    scheduled_seqs = {}
    num_seqs_prefill = 0
    num_batched_tokens = 0

    skipped_waiting_requests: deque[Sequence] = deque()
    num_scheduled_tokens: list[int] = []
    scheduled_spec_decode_tokens: dict[int, np.ndarray] = {}

    if not self.running and not self.waiting:
        # self.block_manager.reset()
        return None
```
Copilot AI (Mar 23, 2026):
schedule() is annotated as returning tuple[ScheduledBatch, dict[int, Sequence]] but it returns None when there are no requests. Please update the return type to Optional[...] (and any callers) to match the actual behavior and avoid type-checking/contract confusion.

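A toy version of the honest-signature pattern the comment asks for, with `Optional[...]` so callers must handle the idle case (the function body is illustrative, not the scheduler's logic):

```python
from typing import Optional

def schedule(waiting: list, running: list) -> Optional[tuple[list, int]]:
    """Annotated to match actual behavior: None means nothing to run."""
    if not waiting and not running:
        return None
    batch = waiting or running
    return batch, len(batch)
```

Callers then unpack only after a None check: `result = schedule(...); if result is None: return`.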
Comment on lines +174 to +175
```python
async with aiohttp.ClientSession(
    timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)
```
Copilot AI (Mar 23, 2026):
The aiohttp ClientTimeout(total=6 * 6000 * 6000) is ~2500 days, which is likely a unit bug and can leave hung requests/sockets around indefinitely under failure. Please replace with a reasonable bound (e.g. seconds/minutes) and/or make it configurable.

Suggested change:

```diff
-async with aiohttp.ClientSession(
-    timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)
+timeout_seconds = int(os.environ.get("DISAGG_PROXY_HTTP_TIMEOUT", "600"))
+async with aiohttp.ClientSession(
+    timeout=aiohttp.ClientTimeout(total=timeout_seconds)
```

Comment on lines +1720 to +1723
```python
    connector.start_load_kv(connector_meta_output)

@torch.inference_mode()
def async_proc_aggregation(self) -> KVConnectorOutput:
```
Copilot AI (Mar 23, 2026):
When the connector is absent, async_proc_aggregation returns KVConnectorOutput(finished_sending=[], finished_recving=[]), but the dataclass fields are sets. Returning KVConnectorOutput() (or sets) avoids type inconsistency and prevents downstream code from accidentally treating these as sequences with duplicates.

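Since the dataclass fields use `default_factory=set`, the no-connector path can return a bare `KVConnectorOutput()` and get empty sets for free. A self-contained sketch mirroring the field definitions from the review excerpt:

```python
from dataclasses import dataclass, field

@dataclass
class KVConnectorOutput:
    # Same fields as the dataclass shown in the review excerpt.
    finished_sending: set = field(default_factory=set)
    finished_recving: set = field(default_factory=set)
    expected_finished_count: int = 0

# The no-connector path can simply return KVConnectorOutput():
# the default_factory already yields empty *sets*, not lists.
out = KVConnectorOutput()
```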


```python
def set_kv_cache_data(
    kv_cache_data: dict[int, KVCacheTensor], config: Optional[Config] = None
```
Copilot AI (Mar 23, 2026):
set_kv_cache_data is typed as dict[int, KVCacheTensor], but ForwardContext.kv_cache_data is dict[str, KVCacheTensor] and the caller builds keys like "layer_{i}". Align the type annotation (and any downstream assumptions) to dict[str, KVCacheTensor] to avoid confusion and static type errors.

Suggested change:

```diff
-    kv_cache_data: dict[int, KVCacheTensor], config: Optional[Config] = None
+    kv_cache_data: dict[str, KVCacheTensor], config: Optional[Config] = None
```

Comment on lines +58 to +65
```python
import json

kv_config_str = kwargs.get("kv_transfer_config", "{}")
try:
    config.kv_transfer_config = json.loads(kv_config_str)
    logger.info(f"KV transfer config loaded: {config.kv_transfer_config}")
except json.JSONDecodeError:
    config.kv_transfer_config = {}
```
Copilot AI (Mar 23, 2026):
kv_transfer_config is already a Config field (and Config.__post_init__ parses string->dict). This extra json.loads(kwargs.get('kv_transfer_config', '{}')) is redundant and unsafe: if kv_transfer_config is already a dict it will raise TypeError (not caught), and it mutates config after CoreManager(config) has been constructed (so the engine core won’t see the updated value anyway). Consider removing this block and relying on Config to parse/hold KV config before CoreManager init.

Suggested change:

```diff
-import json
-
-kv_config_str = kwargs.get("kv_transfer_config", "{}")
-try:
-    config.kv_transfer_config = json.loads(kv_config_str)
-    logger.info(f"KV transfer config loaded: {config.kv_transfer_config}")
-except json.JSONDecodeError:
-    config.kv_transfer_config = {}
```

Copilot AI review requested due to automatic review settings March 23, 2026 13:59
Copilot AI (Contributor) left a comment:

Pull request overview

Copilot reviewed 31 out of 32 changed files in this pull request and generated 8 comments.



Comment on lines 69 to 71
```python
logger.info(
    f"LLMEngine init with {self.data_parallel_size} data parallel ranks"
)
```
Copilot AI (Mar 23, 2026):
There are two identical logger.info("LLMEngine init with ...") calls back-to-back. This duplicates log spam; remove one of them.

Suggested change:

```diff
-logger.info(
-    f"LLMEngine init with {self.data_parallel_size} data parallel ranks"
-)
```

```diff
 }
 # vllm use register_kv_caches to register kv_cache_data. We just set it to global here
-set_kv_cache_data(kv_cache_data)
+set_kv_cache_data(kv_cache_data, config)
```
Copilot AI (Mar 23, 2026):
set_kv_cache_data(kv_cache_data, config) uses config which is not defined in this scope. This will raise NameError during KV cache allocation. Pass self.config (or the intended config object) instead.

Suggested change:

```diff
-set_kv_cache_data(kv_cache_data, config)
+set_kv_cache_data(kv_cache_data, self.config)
```

Comment on lines +303 to +304
```python
req_data["max_tokens"] -= 1
req_data["kv_transfer_params"] = {
```
Copilot AI (Mar 23, 2026):
req_data["max_tokens"] -= 1 will raise if the client sends null/omits max_tokens (it is optional in the OpenAI request model). Guard against None and validate that max_tokens >= 1 before decrementing, otherwise return a 400 with a clear message.

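One way to implement the guard the comment asks for, sketched as a small helper (the helper name and return shape are illustrative, not from the PR):

```python
def decrement_max_tokens(req_data: dict):
    """Guarded version of the prefill-side `max_tokens -= 1`.

    Returns (True, None) on success, or (False, message) so the caller
    can respond with an HTTP 400 instead of raising on None - 1.
    """
    max_tokens = req_data.get("max_tokens")
    if not isinstance(max_tokens, int) or isinstance(max_tokens, bool) or max_tokens < 1:
        return False, "max_tokens must be an integer >= 1"
    req_data["max_tokens"] = max_tokens - 1
    return True, None
```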
```diff
 world_size: int = field(init=False)
 """world_size is TPxPP, it affects the number of workers we create."""
-data_parallel_master_port: int = 29500
+data_parallel_master_port: int = field(default_factory=get_open_port)
```
Copilot AI (Mar 23, 2026):
ParallelConfig.data_parallel_master_port now defaults to get_open_port(). Since each process typically constructs its own Config, different DP ranks can pick different master ports, breaking rendezvous for stateless_init_torch_distributed_process_group (all ranks must use the same host:port). Prefer a deterministic default (e.g. 29500) and/or require this value to be provided via env/args so it is consistent across ranks.

Suggested change:

```diff
-data_parallel_master_port: int = field(default_factory=get_open_port)
+data_parallel_master_port: int = int(os.getenv("ATOM_DATA_PARALLEL_MASTER_PORT", "29500"))
```
Comment on lines 513 to 518
```python
# Skip block append for the first decode step after remote
# prefill — blocks were already allocated during prefill.
if not getattr(seq, "is_first_decode", False):
    self.block_manager.may_append(seq)

num_new_tokens = self.mtp_k + 1
```
Copilot AI (Mar 23, 2026):
In decode scheduling, BlockManager.may_append is called without num_new_tokens, so it defaults to 1. When self.mtp_k > 0 (or any case where more than 1 token is scheduled), this can under-allocate KV blocks and cause out-of-blocks or incorrect block tables. Pass the computed num_new_tokens into may_append (while still skipping the first-decode-after-remote-prefill case).

Suggested change:

```diff
-# Skip block append for the first decode step after remote
-# prefill — blocks were already allocated during prefill.
-if not getattr(seq, "is_first_decode", False):
-    self.block_manager.may_append(seq)
-
-num_new_tokens = self.mtp_k + 1
+num_new_tokens = self.mtp_k + 1
+# Skip block append for the first decode step after remote
+# prefill — blocks were already allocated during prefill.
+if not getattr(seq, "is_first_decode", False):
+    self.block_manager.may_append(seq, num_new_tokens)
```
Comment on lines +413 to +415
```python
if not (hasattr(config, "kv_transfer_config") and config.kv_transfer_config):
    return _global_kvconnector
```
Copilot AI (Mar 23, 2026):
get_kvconnector returns _global_kvconnector when config is missing/disabled regardless of role. For role="scheduler", this can return the worker connector (or None) even if a scheduler connector was previously initialized. Consider returning _global_kvconnector_scheduler when role == "scheduler" (and/or branching the early return by role).

Comment on lines +438 to +441
```python
_global_kvconnector_scheduler = KVConnectorFactory.create_connector(
    config, role="scheduler"
)
_logger.debug("Initialized global KVConnectorScheduler")
```
Copilot AI (Mar 23, 2026):
The scheduler connector branch recreates a new connector on every call (no _global_kvconnector_scheduler is None guard). This can leak resources/threads and lose internal state. Cache the scheduler connector similarly to the worker connector and only create it once per process.

Suggested change:

```diff
-_global_kvconnector_scheduler = KVConnectorFactory.create_connector(
-    config, role="scheduler"
-)
-_logger.debug("Initialized global KVConnectorScheduler")
+if _global_kvconnector_scheduler is None:
+    _global_kvconnector_scheduler = KVConnectorFactory.create_connector(
+        config, role="scheduler"
+    )
+    _logger.debug("Initialized global KVConnectorScheduler")
```

```python
if (
    self._ready_requests.empty()
    and remote_engine_id not in self.load_ready_flag
):
```
Copilot AI (Mar 23, 2026):
The while need_handshake: loop busy-spins with continue when the handshake isn't ready yet, which can peg a CPU core. Add a small sleep/backoff (or use blocking Queue.get/Event) to avoid hot spinning while waiting for the handshake/ready queue.

Suggested change:

```diff
 ):
+    time.sleep(0.001)
```
