Conversation
- clear schedule redundant variables
- async basic
- update
- update util
- update multi work rpc
- update
- add connector base
- add meta info
- update, finished return correct response
- request output stage1
- request output stage2 , finished
- finished ping proxy draft
- add proxy
- atom config to connector
- proxy pass
- update start loadkv
- handshake
- sess none
- fix and add sleep
- update ,transfer without crash, need to *16, need to do get_finished&scheduler
- add finished , but 3 step run will met none
- try to add decode only
- run but hang in single test(no toproxy)
- fix hang
- sperate process kv output
- pass test , need to do decode onlyt
- update, meet prepare input bu
- run success but accuracy false
- update , seems 1 seq random correct?
- update
- disable deffered out
- run gsmk8 con=64 without crash(ds5layer)
- update 0116 status to feiyue
- update
- fix proxy
- update
- update, prefill instance can do full task
- fix high con random acc issue
- remove debug
- update
- fix bench perf
- update
- update
- gsm 0.92
- update
- fix hang
- refactor: remove redundant comments and re-adjust code organization
- fix refactor
- use transfer id (ROCm#6)
- update test
- use transfer_id
- update
- use transfer_id to fix prefill decode mismatch
- Use transfer (ROCm#7)
- update test
- use transfer_id
- update
- use transfer_id to fix prefill decode mismatch
- update assert
- fix merge ,runwithout crash
- run without crash ,but meet acc issue
- update
- fix
- enable aiter log

Co-authored-by: root <root@useocpm2m-097-088.amd.com>
Co-authored-by: root <root@useocpm2m-097-083.amd.com>
Co-authored-by: knitcapcat <zejwang@amd.com>
Pull request overview
Adds Prefill/Decode (P/D) KV-cache disaggregation support to ATOM, integrating a KV transfer connector (RDMA via MORI-IO), a routing proxy, and end-to-end plumbing through scheduling, runner execution, and the OpenAI-compatible server.
Changes:
- Introduces `atom/disaggregation/` (KV transfer engine, KV output aggregation, proxy service) plus documentation.
- Wires KV-transfer state through engine components (scheduler/sequence/model_runner/async worker manager) and exposes transfer metadata via the OpenAI API.
- Adds configuration/CLI support for `kv_transfer_config` and a small unit test for KV output aggregation.
Reviewed changes
Copilot reviewed 22 out of 22 changed files in this pull request and generated 14 comments.
Show a summary per file
| File | Description |
|---|---|
| test/disaggregation/test_kv_aggregator.py | Adds a basic test for cross-worker KV transfer completion aggregation. |
| atom/utils/network.py | Adds IP detection helper and env-var handling for host IP selection. |
| atom/utils/forward_context.py | Adds lazy global KV connector instances and registers KV caches into the connector. |
| atom/model_ops/attentions/aiter_mla.py | Adjusts decode metadata preparation for first-decode-after-remote-prefill behavior. |
| atom/model_engine/sequence.py | Adds new sequence status and KV-transfer-related fields. |
| atom/model_engine/scheduler.py | Integrates KV disaggregation into scheduling/postprocess and carries connector metadata in batches. |
| atom/model_engine/request.py | Extends streaming output object to include KV transfer output metadata. |
| atom/model_engine/model_runner.py | Adds worker-side dispatch of connector metadata and worker KV-finish reporting. |
| atom/model_engine/llm_engine.py | Parses KV transfer config and forwards per-request KV transfer params into sequences. |
| atom/model_engine/engine_core.py | Dispatches connector metadata to workers and aggregates KV completion across TP ranks each step. |
| atom/model_engine/block_manager.py | Introduces a constant intended for remote/prefill hashing coordination. |
| atom/model_engine/async_proc.py | Adds a dedicated KV output channel and aggregated RPC support across workers. |
| atom/model_engine/arg_utils.py | Adds CLI argument and default for `--kv-transfer-config`. |
| atom/entrypoints/openai_server.py | Accepts KV transfer params on requests and returns KV transfer metadata in responses/streaming chunks. |
| atom/disaggregation/requirements.txt | Adds disaggregation-specific Python dependencies list. |
| atom/disaggregation/proxy.py | Adds Quart-based proxy for service discovery + routing prefill/decode requests. |
| atom/disaggregation/kvoutput_aggregator.py | Adds aggregator to combine per-TP-worker finished send/recv signals. |
| atom/disaggregation/kv_transfer_engine.py | Adds MORI-IO based KV cache transfer connector (worker + scheduler components). |
| atom/disaggregation/__init__.py | Initializes the disaggregation package. |
| atom/disaggregation/README.md | Documents how to run proxy/prefill/decode for P/D disaggregation. |
| atom/config.py | Adds `kv_transfer_config` field and parsing for string-based configs. |
| README.md | Adds top-level note announcing P/D disaggregation support and doc link. |
```python
)
stream_gen = stream_decode_response(session, decode_response, request_id)
response = await make_response(stream_gen)
response.headers["Content-Type"] = "application/json; charset=utf-8"
```

The proxy streams bytes from the decode instance (which are SSE-formatted in the OpenAI streaming API), but forces `Content-Type: application/json; charset=utf-8`. This can break streaming clients expecting `text/event-stream`. Consider preserving the upstream content-type or explicitly setting `text/event-stream` for streaming responses.

Suggested change:

```diff
-response.headers["Content-Type"] = "application/json; charset=utf-8"
+response.headers["Content-Type"] = "text/event-stream; charset=utf-8"
```
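A small pure-Python sketch of the "preserve upstream, then fall back" policy the comment suggests. The helper name and `streaming` flag are illustrative, not ATOM's actual API:

```python
def forwarded_content_type(upstream_headers: dict, streaming: bool) -> str:
    """Choose the Content-Type for a proxied response.

    Prefer whatever the decode instance sent; only fall back to a
    mode-appropriate default instead of hard-coding application/json.
    (A real proxy should use a case-insensitive header lookup.)
    """
    upstream = upstream_headers.get("Content-Type") or upstream_headers.get("content-type")
    if upstream:
        return upstream
    return "text/event-stream; charset=utf-8" if streaming else "application/json; charset=utf-8"

# Streaming upstream already declares SSE; preserve it verbatim.
print(forwarded_content_type({"Content-Type": "text/event-stream"}, streaming=True))
# No upstream header: fall back based on whether we are streaming.
print(forwarded_content_type({}, streaming=False))  # application/json; charset=utf-8
```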
```python
num_blocks_per_seq_bk = [
    (ctx + self.block_size - 1) // self.block_size for ctx in batch.context_lens
]
```

`num_blocks_per_seq_bk` is computed but never used. Remove it (or use it) to avoid dead code and keep the decode metadata preparation easier to follow.

Suggested change (delete the dead code):

```diff
-num_blocks_per_seq_bk = [
-    (ctx + self.block_size - 1) // self.block_size for ctx in batch.context_lens
-]
```
```python
self.is_first_decode_without_local_prefill = [
    seq.is_first_decode for seq in seqs.values()
]
self.temperatures = [seq.temperature for seq in seqs.values()]
self.context_lens = [seq.num_tokens for seq in seqs.values()]

# Build the flat scheduled-token array
offs = self.context_lens - self.num_rejected - self.num_scheduled_tokens
self.scheduled_tokens = np.empty(total_tokens_num, dtype=np.int32)
```

`ScheduledBatch.__init__` sets `self.context_lens` to a Python list, but then computes `offs = self.context_lens - self.num_rejected - self.num_scheduled_tokens`, which will raise `TypeError` (list minus ndarray). Keep `context_lens` as a NumPy array (as before) or explicitly convert before doing vectorized arithmetic so batch construction works.
atom/model_engine/scheduler.py (Outdated)

```python
if not getattr(seq, "is_first_decode", False):
    self.block_manager.may_append(seq)

num_new_tokens = self.mtp_k + 1
```

Decode scheduling calls `self.block_manager.may_append(seq)` without passing `num_new_tokens` (it defaults to 1), even though each decode step schedules `num_new_tokens = self.mtp_k + 1` tokens. This can under-allocate KV blocks when `mtp_k > 0` (or near a block boundary) and lead to out-of-bounds writes later. Pass the correct `num_new_tokens` to `may_append` (and ensure the first-decode path still guarantees enough blocks).

Suggested change:

```diff
-if not getattr(seq, "is_first_decode", False):
-    self.block_manager.may_append(seq)
-num_new_tokens = self.mtp_k + 1
+num_new_tokens = self.mtp_k + 1
+if not getattr(seq, "is_first_decode", False):
+    self.block_manager.may_append(seq, num_new_tokens=num_new_tokens)
```
atom/model_engine/model_runner.py (Outdated)

```python
@torch.inference_mode()
def async_proc_aggregation(self) -> KVConnectorOutput:
    """Collect finished send/recv status from the KV connector."""
    done_sending, done_recving = get_kvconnector().get_finished()
```

`async_proc_aggregation` unconditionally does `get_kvconnector().get_finished()`. When `kv_transfer_config` is not enabled (or the connector failed to initialize), `get_kvconnector()` can return `None`, causing an `AttributeError`. Guard for `None` and return an empty `KVConnectorOutput` so `EngineCore` can call aggregation safely even in non-disaggregation runs.

Suggested change (using empty sets, matching the `set[...]` fields on `KVConnectorOutput`):

```diff
-done_sending, done_recving = get_kvconnector().get_finished()
+connector = get_kvconnector()
+if connector is None:
+    # No KV connector available (e.g., non-disaggregation run); return an "empty" result.
+    return KVConnectorOutput(finished_sending=set(), finished_recving=set())
+done_sending, done_recving = connector.get_finished()
```
atom/model_engine/engine_core.py (Outdated)

```python
)
self.input_thread.start()

self.kv_aggregator = KVOutputAggregator(world_size=config.tensor_parallel_size)
```

`EngineCore.__init__` assigns `self.kv_aggregator = KVOutputAggregator(...)` but this field is never referenced (aggregation is done via `runner_mgr.call_func_with_aggregation`). Remove the unused instance variable or switch the code to use it to avoid dead state.
```diff
@@ -103,6 +147,7 @@ def send_output_to_socket(self, addr: str, output_queue: queue.Queue):
     socket.send(serialized_obj)
```

`send_output_to_socket` loops with `while True:` and blocks on `output_queue.get()` without checking `still_running` or consuming a sentinel, so `AsyncIOProc.exit()` cannot reliably stop/join these threads (especially now that an extra KV-output thread is created). Add a shutdown condition (e.g., check `still_running` and/or send a sentinel through the queue) so worker processes can exit cleanly.
atom/utils/network.py (Outdated)

```python
import os
import socket
import warnings
```

New module `atom/utils/network.py` is missing the SPDX license header used throughout `atom/utils/*` (e.g., `atom/utils/envs.py`). Add the standard SPDX + copyright header for consistency and licensing clarity.
```python
if not (hasattr(config, "kv_transfer_config") and config.kv_transfer_config):
    return _global_kvconnector

if role == "worker":
    from aiter.dist.parallel_state import get_tp_group

    try:
        tp_rank = get_tp_group().rank_in_group
    except Exception:
        _logger.warning(
            "get_tp_group() failed (dist not initialized?), returning None"
        )
        return None

    if _global_kvconnector is None:
        from atom.disaggregation.kv_transfer_engine import KVConnector

        _global_kvconnector = KVConnector(config)
        _logger.debug("Initialized global KVConnector at tp_rank %d", tp_rank)

elif role == "scheduler":
    from atom.disaggregation.kv_transfer_engine import KVConnectorScheduler

    _global_kvconnector_scheduler = KVConnectorScheduler(config)
    _logger.debug("Initialized global KVConnectorScheduler")
    return _global_kvconnector_scheduler
```

`get_kvconnector` returns `_global_kvconnector` when KV transfer is not configured, even for `role="scheduler"`. This makes the return value role-inconsistent and can leak a stale worker connector; callers also can't distinguish "disabled" from "initialized earlier". Return `None` when not configured, and return the role-appropriate global (`_global_kvconnector_scheduler` vs `_global_kvconnector`).
```python
)

# --- Decode request ---
req_data["max_tokens"] -= 1
```

Proxy assumes `req_data["max_tokens"]` exists and is an int (`req_data["max_tokens"] -= 1`). Since this proxy accepts raw JSON (no schema validation), requests without `max_tokens` (or with `None`) will raise. Use a safe default / validation and handle `max_completion_tokens` as well if you intend to support both OpenAI-style fields.

Suggested change:

```python
max_tokens = req_data.get("max_tokens")
max_completion_tokens = req_data.get("max_completion_tokens")
effective_max_tokens = None
if isinstance(max_tokens, int):
    effective_max_tokens = max_tokens
elif isinstance(max_completion_tokens, int):
    effective_max_tokens = max_completion_tokens
if effective_max_tokens is not None:
    # Reserve one token for prefill and ensure non-negative.
    decode_max_tokens = max(0, effective_max_tokens - 1)
    req_data["max_tokens"] = decode_max_tokens
    if isinstance(max_completion_tokens, int) and not isinstance(max_tokens, int):
        # Keep max_completion_tokens in sync if it was the source.
        req_data["max_completion_tokens"] = decode_max_tokens
```
Pull request overview
Copilot reviewed 22 out of 23 changed files in this pull request and generated 22 comments.
```python
while True:
    try:
        identity, msg = sock.recv_multipart()
        self._handle_message(msg)
```

`async_wait_reqid()` calls `self._handle_message(msg)`, but there is no `_handle_message` method defined on `MoRIIOWrapper` (only `_dispatch_message`). If this code path is exercised it will raise `AttributeError` and break notifications. Rename the call to the correct handler or implement `_handle_message`.

Suggested change:

```diff
-self._handle_message(msg)
+self._dispatch_message(msg)
```
```python
def write_remote_data_single(
    self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0
):
    assert self.local_memory_registered, "You have not register local memory data!"
    assert self.moriio_engine is not None, "MoRIIO engine must be set first"
    transfer_status = self.sessions[sess_idx].write(
        local_offset,
        remote_offset,
        transfer_size_byte,
        self.moriio_engine.allocate_transfer_uid(),
    )
    with self.lock:
        self.transfer_status.append(transfer_status)
```

`write_remote_data_single()` references `self.sessions[sess_idx]`, but `self.sessions` is never initialized in `MoRIIOWrapper`, so this will raise `AttributeError` if called. Either initialize `self.sessions` (and document its lifecycle) or remove/replace this method if it's not used.
| """Collect finished send/recv status from the KV connector.""" | ||
| connector = get_kvconnector() | ||
| if connector is None: | ||
| return KVConnectorOutput(finished_sending=[], finished_recving=[]) |
There was a problem hiding this comment.
When connector is None, async_proc_aggregation() returns KVConnectorOutput(finished_sending=[], finished_recving=[]). The KVConnectorOutput dataclass is defined with set[...] fields, so returning lists here is inconsistent and can lead to surprises if callers assume set semantics. Prefer returning empty sets (or relying on defaults) for both fields.
| return KVConnectorOutput(finished_sending=[], finished_recving=[]) | |
| return KVConnectorOutput(finished_sending=set(), finished_recving=set()) |
```python
# Phase 2: All workers have finished sending req_id="1"
worker_outputs = [
    KVConnectorOutput(finished_sending={"1"}) for _ in range(8)
]
finished = aggregator.aggregate(worker_outputs)
print(f"Round 2 finished_sending: {finished.finished_sending}, finished_recving: {finished.finished_recving}")
assert finished.finished_sending == {"1"}
assert finished.finished_recving == set()

# Phase 3: All workers have finished receiving req_id="1"
worker_outputs = [
    KVConnectorOutput(finished_recving={"1"}) for _ in range(8)
]
finished = aggregator.aggregate(worker_outputs)
print(f"Round 3 finished_sending: {finished.finished_sending}, finished_recving: {finished.finished_recving}")
assert finished.finished_sending == set()
assert finished.finished_recving == {"1"}

print("\nTest 2: Complex scenario")
aggregator2 = KVOutputAggregator(world_size=8)

# First round: Some workers complete some requests
worker_outputs = [
    KVConnectorOutput(finished_sending={"1", "2"}, finished_recving={"1"}),
    KVConnectorOutput(finished_sending={"1"}),
    KVConnectorOutput(finished_sending={"2"}, finished_recving={"2"}),
    KVConnectorOutput(finished_sending={"1"}),
    KVConnectorOutput(),  # Empty output
    KVConnectorOutput(finished_sending={"1"}, finished_recving={"3"}),
    KVConnectorOutput(finished_sending={"2"}),
    KVConnectorOutput(finished_sending={"1"})
```

The test uses string request IDs (e.g. `"1"`) for `finished_sending`/`finished_recving`, but the engine appears to use integer `seq.id`/`transfer_id` values for KV completion tracking. Using the same ID type in tests as production (likely `int`) would better catch type-mismatch bugs where completions never match scheduler sequence IDs.

Suggested change (same scenario with integer request IDs):

```python
# Phase 2: All workers have finished sending req_id=1
worker_outputs = [
    KVConnectorOutput(finished_sending={1}) for _ in range(8)
]
finished = aggregator.aggregate(worker_outputs)
print(f"Round 2 finished_sending: {finished.finished_sending}, finished_recving: {finished.finished_recving}")
assert finished.finished_sending == {1}
assert finished.finished_recving == set()

# Phase 3: All workers have finished receiving req_id=1
worker_outputs = [
    KVConnectorOutput(finished_recving={1}) for _ in range(8)
]
finished = aggregator.aggregate(worker_outputs)
print(f"Round 3 finished_sending: {finished.finished_sending}, finished_recving: {finished.finished_recving}")
assert finished.finished_sending == set()
assert finished.finished_recving == {1}

print("\nTest 2: Complex scenario")
aggregator2 = KVOutputAggregator(world_size=8)

# First round: Some workers complete some requests
worker_outputs = [
    KVConnectorOutput(finished_sending={1, 2}, finished_recving={1}),
    KVConnectorOutput(finished_sending={1}),
    KVConnectorOutput(finished_sending={2}, finished_recving={2}),
    KVConnectorOutput(finished_sending={1}),
    KVConnectorOutput(),  # Empty output
    KVConnectorOutput(finished_sending={1}, finished_recving={3}),
    KVConnectorOutput(finished_sending={2}),
    KVConnectorOutput(finished_sending={1})
```
```python
):
    seq = self.waiting[0]
# --- Prefill scheduling ---
while self.waiting and num_seqs_prefill < self.max_num_seqs:
```

The prefill scheduling loop no longer applies the `delay_factor` batching logic (`_passed_delay()`), which changes scheduling behavior and can reduce batching efficiency under load. If this is intentional for disaggregation, consider gating it only for the remote-KV path; otherwise reintroduce the delay condition to preserve existing throughput/latency tradeoffs.

Suggested change:

```diff
 while self.waiting and num_seqs_prefill < self.max_num_seqs:
+    # For non-disaggregated setups, respect delay_factor-based batching
+    # to preserve existing throughput/latency tradeoffs. The delay is
+    # disabled when using a KV connector, since disaggregation already
+    # governs prefill timing via remote KV availability.
+    if self.kv_connector is None and not self._passed_delay():
+        break
```
```python
@dataclass
class KVConnectorOutput:
    """Per-worker snapshot of finished KV cache transfers.

    Each TP worker produces one of these per scheduler step. The
    :class:`KVOutputAggregator` combines them to determine which
    request IDs have finished on *all* workers.

    Attributes:
        finished_sending: Request IDs whose KV send completed on this worker.
        finished_recving: Request IDs whose KV receive completed on this worker.
    """

    finished_sending: set[str] = field(default_factory=set)
    finished_recving: set[str] = field(default_factory=set)
```

`KVConnectorOutput` uses `set[str]` for request IDs, but the rest of the engine uses integer sequence/request IDs. Standardizing this type (e.g., `set[int]` or a `ReqId = int` alias) would prevent hard-to-debug mismatches where finished IDs never match scheduler sequence IDs.
```diff
 world_size: int = field(init=False)
 """world_size is TPxPP, it affects the number of workers we create."""
-data_parallel_master_port: int = 29500
+data_parallel_master_port: int = field(default_factory=get_open_port)
 """Port of the data parallel master."""

 data_parallel_base_port: int = get_open_port()
```

`data_parallel_master_port` now defaults to `get_open_port()`. If each DP rank constructs its own `Config` independently (common in multi-process launchers), they may pick different master ports and fail to rendezvous. The previous fixed default (29500) avoided this class of mismatch. If the goal is to avoid port collisions, consider deriving a single deterministic port (e.g., from an env var like `ATOM_DP_MASTER_PORT`, or selecting once on rank 0 and broadcasting) rather than calling `get_open_port()` during dataclass construction on every rank.
```text
msgspec
msgpack
quart
```

These dependencies are required by new runtime code (`kv_transfer_engine.py` imports msgspec/msgpack, `proxy.py` imports quart), but they are not included in the package install deps (pyproject.toml `[project].dependencies`). As-is, enabling disaggregation will likely fail with `ModuleNotFoundError` unless users install this extra file manually. Consider adding them to pyproject.toml (possibly as an optional extra) and referencing that in docs.
README.md (Outdated)

```markdown
## 📢 News

- **[2026/03]** ATOM now supports **Prefill/Decode (P/D) disaggregation** — run prefill and decode on separate GPU nodes with RDMA-based KV cache transfer via [MORI-IO](https://github.com/ROCm/mori). See [disaggregation docs](atom/disaggregation/README.md).
```

The link path appears incorrect: this PR adds disaggregation docs under `atom/mesh/disaggregation/README.md`, but the README links to `atom/disaggregation/README.md`. Update the link to the actual location so it doesn't 404.

Suggested change:

```diff
-... See [disaggregation docs](atom/disaggregation/README.md).
+... See [disaggregation docs](atom/mesh/disaggregation/README.md).
```
```python
num_blocks_per_seq_bk = [
    (ctx + self.block_size - 1) // self.block_size for ctx in batch.context_lens
]
num_blocks_per_seq = []
for i, (ctx, is_first) in enumerate(zip(batch.context_lens, batch.is_first_decode_without_local_prefill)):
    if is_first:
        # First decode after remote prefill: use pre-allocated block count
        blocks = len(batch.block_tables[i])
    else:
        # Normal case: ceil-divide context length by block size
        blocks = (ctx + self.block_size - 1) // self.block_size
    num_blocks_per_seq.append(blocks)
sum_blocks_before_converted = sum([(i + self.block_ratio - 1) // self.block_ratio for i in num_blocks_per_seq])
```

`num_blocks_per_seq` is computed from `batch.context_lens`, but earlier in this method `context_lens` is adjusted (e.g., subtracting `num_rejected` for speculative decode) and `block_tables` may be truncated accordingly. Using the unadjusted `batch.context_lens` can make `kv_indptr` inconsistent with the actual block tables used for slot mapping, potentially breaking decode with speculative rejection. Prefer deriving block counts from the adjusted `context_lens`/`block_tables` used above.

Suggested change:

```python
# Derive the number of blocks per sequence from the prepared block tables
# to ensure consistency with any adjustments (e.g., speculative rejection).
num_blocks_per_seq = [len(bt) for bt in batch.block_tables]
sum_blocks_before_converted = sum(
    [(i + self.block_ratio - 1) // self.block_ratio for i in num_blocks_per_seq]
)
```
Pull request overview
Copilot reviewed 30 out of 31 changed files in this pull request and generated 9 comments.
```python
if not prefill_instances or not decode_instances:
    return await make_response((
        "Service Unavailable: no prefill or decode instances registered.",
        503,
    ))

# Round-robin instance selection
pid = request_nums % len(prefill_instances)
did = request_nums % len(decode_instances)
prefill_ep = prefill_instances[pid]
decode_ep = decode_instances[did]
```

`handle_request()` reads `prefill_instances` / `decode_instances` without `_list_lock`, but those lists are mutated concurrently by the registration listener thread. This can lead to inconsistent `len()` / indexing (e.g., `IndexError`), especially around `pid`/`did` selection. Wrap the availability check + endpoint selection in `with _list_lock:` (or copy the lists under lock before using them).

Suggested change:

```python
# Protect access to instance lists with the shared lock to avoid
# races with the registration listener thread.
with _list_lock:
    if not prefill_instances or not decode_instances:
        return await make_response((
            "Service Unavailable: no prefill or decode instances registered.",
            503,
        ))
    # Round-robin instance selection
    pid = request_nums % len(prefill_instances)
    did = request_nums % len(decode_instances)
    prefill_ep = prefill_instances[pid]
    decode_ep = decode_instances[did]
```
```shell
python -m atom.entrypoints.openai_server \
    --kv_cache_dtype fp8 \
    --model /path/to/model \
    --block-size 16 \
    -tp 8 \
    --enable-dp-attention \
    --enable-expert-parallel \
    --kv-transfer-config '{"kv_role":"kv_consumer","proxy_ip":"<PROXY_IP>","proxy_ping_port":36367,"http_prt":8000}'
```

This command example uses `"http_prt"` in `--kv-transfer-config`, but the code reads `http_port`. With the current spelling, the HTTP port override won't take effect. Replace `http_prt` with `http_port` here.
atom/utils/network.py (Outdated)

```python
import os
import socket
import warnings

def get_ip() -> str:
    # Check environment variable first
```

`atom/utils/network.py` is missing the SPDX license/copyright header that appears to be standard across the codebase (e.g., `atom/utils/envs.py:1-2`, `atom/utils/forward_context.py:1-3`). Add the SPDX header here for consistency and automated license scanning.
```python
if not (hasattr(config, "kv_transfer_config") and config.kv_transfer_config):
    return _global_kvconnector
```

`get_kvconnector()` returns `_global_kvconnector` whenever `config.kv_transfer_config` is missing/empty, regardless of the requested role. This can return the worker connector when the caller asked for the scheduler connector (or vice versa) if one was previously initialized. Consider returning `_global_kvconnector_scheduler` when `role == "scheduler"` (and `_global_kvconnector` for `"worker"`), even in the early-return path.
```diff
     ],
     record_shapes=enable_detailed_profiling,
-    with_stack=enable_detailed_profiling,
+    with_stack=True,
```

`start_profiler()`'s docstring says `with_stack` is controlled by `ATOM_PROFILER_MORE`, but the profiler is created with `with_stack=True` unconditionally. Capturing stacks is expensive and can significantly increase profiling overhead and trace size; set `with_stack=enable_detailed_profiling` (or update the docstring if always-on is intentional).

Suggested change:

```diff
-with_stack=True,
+with_stack=enable_detailed_profiling,
```
```shell
python -m atom.entrypoints.openai_server \
    --kv_cache_dtype fp8 \
    --model /path/to/model \
    --block-size 16 \
    -tp 8 \
    --enable-dp-attention \
    --enable-expert-parallel \
    --kv-transfer-config '{"kv_role":"kv_producer","proxy_ip":"<PROXY_IP>","proxy_ping_port":36367,"http_prt":8000}'
```

This command example uses `"http_prt"` in `--kv-transfer-config`, but the code reads `http_port`. With the current spelling, the HTTP port override won't take effect. Replace `http_prt` with `http_port` here.
| # If a handshake was needed, spin until it completes then read. | ||
| while need_handshake: | ||
| if ( | ||
| self._ready_requests.empty() | ||
| and remote_engine_id not in self.load_ready_flag | ||
| ): | ||
| continue | ||
| elif ( | ||
| not self._ready_requests.empty() | ||
| and remote_engine_id in self.load_ready_flag | ||
| ): | ||
| self._issue_read_for_req(*self._ready_requests.get_nowait()) | ||
| break | ||
| else: | ||
| break | ||
|
|
start_load_kv() busy-spins in while need_handshake: with a tight continue loop while waiting for handshake completion. This can peg a CPU core during handshake latency. Prefer a blocking wait (e.g., Queue.get() with timeout) or at least add a short sleep/backoff and a clear exit condition; also consider tracking per-engine futures/events instead of polling shared dicts.
| # If a handshake was needed, spin until it completes then read. | |
| while need_handshake: | |
| if ( | |
| self._ready_requests.empty() | |
| and remote_engine_id not in self.load_ready_flag | |
| ): | |
| continue | |
| elif ( | |
| not self._ready_requests.empty() | |
| and remote_engine_id in self.load_ready_flag | |
| ): | |
| self._issue_read_for_req(*self._ready_requests.get_nowait()) | |
| break | |
| else: | |
| break | |
| # If a handshake was needed, wait until it completes then read. | |
| if need_handshake: | |
| handshake_timeout = 5.0 # seconds | |
| start_time = time.time() | |
| while True: | |
| # Exit if handshake takes too long to avoid indefinite spinning. | |
| if time.time() - start_time > handshake_timeout: | |
| logger.warning( | |
| "Timed out waiting for KV handshake with remote engine %s " | |
| "after %.1f seconds", | |
| remote_engine_id, | |
| handshake_timeout, | |
| ) | |
| break | |
| if ( | |
| self._ready_requests.empty() | |
| and remote_engine_id not in self.load_ready_flag | |
| ): | |
| # Handshake not yet complete; back off briefly to avoid busy-spin. | |
| time.sleep(0.001) | |
| continue | |
| elif ( | |
| not self._ready_requests.empty() | |
| and remote_engine_id in self.load_ready_flag | |
| ): | |
| self._issue_read_for_req(*self._ready_requests.get_nowait()) | |
| break | |
| else: | |
| # Either handshake completed but no matching ready request, or | |
| # state changed unexpectedly; exit loop and let caller decide. | |
| break |
| logger.info( | ||
| f"LLMEngine init with {self.data_parallel_size} data parallel ranks" | ||
| ) |
There are two identical logger.info calls for the data-parallel rank count, which will duplicate logs on every engine init. Remove one of them to keep logs clean.
| logger.info( | |
| f"LLMEngine init with {self.data_parallel_size} data parallel ranks" | |
| ) |
| # Skip block append for the first decode step after remote | ||
| # prefill — blocks were already allocated during prefill. | ||
| if not getattr(seq, "is_first_decode", False): | ||
| self.block_manager.may_append(seq) | ||
|
|
||
| num_new_tokens = self.mtp_k + 1 | ||
| self.block_manager.may_append(seq, num_new_tokens) | ||
| scheduled_seqs[seq.id] = seq | ||
| seq.type = SequenceType.DECODE | ||
| num_scheduled_tokens.append(num_new_tokens) | ||
|
|
||
| num_scheduled_tokens_np = num_scheduled_tokens | ||
| total_tokens_num_decode = sum(num_scheduled_tokens_np) | ||
|
|
||
| assert scheduled_seqs | ||
| self.running.extendleft(reversed(scheduled_seqs.values())) | ||
| # logger.info( | ||
| # f"Scheduled decode batch: {num_seqs_decode} reqs, {total_tokens_num_decode} tokens, req_ids: {tuple(scheduled_seqs.keys())}" | ||
| # ) | ||
| return ( | ||
| ScheduledBatch( | ||
| seqs=scheduled_seqs, | ||
| num_scheduled_tokens=num_scheduled_tokens_np, | ||
| total_tokens_num=total_tokens_num_decode, | ||
| total_tokens_num_decode=total_tokens_num_decode, | ||
| total_seqs_num=num_seqs_prefill + num_seqs_decode, | ||
| total_seqs_num_prefill=num_seqs_prefill, | ||
| total_seqs_num_decode=num_seqs_decode, | ||
| num_spec_step=self.mtp_k, | ||
| scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, | ||
| ), | ||
| scheduled_seqs, | ||
| seq.is_first_decode = False |
seq.is_first_decode is being reset to False during decode scheduling before ScheduledBatch is constructed. Since ScheduledBatch.__init__ reads seq.is_first_decode to populate is_first_decode_without_local_prefill, the batch will never see True, breaking the “first-decode-after-remote-prefill” path (e.g., MLA decode metadata will compute block counts as if local prefill happened). Reset the flag only after the batch snapshot is created (or snapshot the flag into a local variable / batch field before mutating the sequence).
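One way to preserve the flag is to snapshot it into the batch before mutating the sequence; a sketch with hypothetical, heavily simplified types:

```python
from dataclasses import dataclass, field

@dataclass
class Seq:
    id: int
    is_first_decode: bool = False

@dataclass
class ScheduledBatch:
    # Simplified stand-in for the field the comment refers to.
    is_first_decode_without_local_prefill: set[int] = field(default_factory=set)

def schedule_decode(seqs: list[Seq]) -> ScheduledBatch:
    # Snapshot the flag BEFORE resetting it, so the batch still observes
    # the first-decode-after-remote-prefill state.
    batch = ScheduledBatch({s.id for s in seqs if s.is_first_decode})
    for s in seqs:
        s.is_first_decode = False  # reset only after the snapshot
    return batch
```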
Pull request overview
Copilot reviewed 30 out of 31 changed files in this pull request and generated 15 comments.
Comments suppressed due to low confidence (2)
atom/model_engine/model_runner.py:15
torch.distributed as dist is imported but not used anywhere in this file. If linting is enabled this will fail CI; otherwise it's dead code and should be removed.
import torch
import torch.profiler as torch_profiler
import tqdm
atom/model_engine/llm_engine.py:201
Sequence(...) is now constructed without passing num_draft_tokens=self.num_speculative_tokens and mamba_enabled=self.mamba_enabled. This will silently disable speculative decoding and mamba handling for all requests when KV disaggregation is enabled (and even when it isn't, since the defaults are 0/False). Please pass the existing parameters through while adding kv_transfer_params.
if self.config.hf_config.model_type == "qwen3_next":
self.mamba_enabled = True
def preprocess(
self,
prompt_or_tokens: str | list[int],
sampling_params: SamplingParams,
stream_callback=None,
kv_transfer_params=None,
):
"""responsible for:
1) Tokenize
2) Create Sequence object"""
tokens = (
self.tokenizer.encode(prompt_or_tokens)
if isinstance(prompt_or_tokens, str)
else prompt_or_tokens
)
stop_token_sequences = []
if sampling_params.stop_strings:
stops = (
[sampling_params.stop_strings]
if isinstance(sampling_params.stop_strings, str)
else sampling_params.stop_strings
)
for stop_str in stops:
| from atom.mesh.disaggregation import KVConnectorFactory | ||
|
|
||
| _global_kvconnector_scheduler = KVConnectorFactory.create_connector( | ||
| config, role="scheduler" | ||
| ) | ||
| _logger.debug("Initialized global KVConnectorScheduler") |
In the role == "scheduler" branch, _global_kvconnector_scheduler is re-created on every call (no if _global_kvconnector_scheduler is None guard). That can leak resources / duplicate background threads/sockets if get_kvconnector("scheduler", ...) is called multiple times; make it truly lazy like the worker connector.
| from atom.mesh.disaggregation import KVConnectorFactory | |
| _global_kvconnector_scheduler = KVConnectorFactory.create_connector( | |
| config, role="scheduler" | |
| ) | |
| _logger.debug("Initialized global KVConnectorScheduler") | |
| if _global_kvconnector_scheduler is None: | |
| from atom.mesh.disaggregation import KVConnectorFactory | |
| _global_kvconnector_scheduler = KVConnectorFactory.create_connector( | |
| config, role="scheduler" | |
| ) | |
| _logger.debug("Initialized global KVConnectorScheduler") |
atom/model_engine/scheduler.py (Outdated)
| # Skip block append for the first decode step after remote | ||
| # prefill — blocks were already allocated during prefill. | ||
| if not getattr(seq, "is_first_decode", False): | ||
| self.block_manager.may_append(seq) | ||
|
|
||
| num_new_tokens = self.mtp_k + 1 |
Decode scheduling computes num_new_tokens = self.mtp_k + 1 but calls self.block_manager.may_append(seq) without passing num_new_tokens. For speculative decode (mtp_k>0) this can under-allocate KV blocks, since BlockManager.may_append uses num_new_tokens to decide how many blocks are needed.
| # Skip block append for the first decode step after remote | |
| # prefill — blocks were already allocated during prefill. | |
| if not getattr(seq, "is_first_decode", False): | |
| self.block_manager.may_append(seq) | |
| num_new_tokens = self.mtp_k + 1 | |
| num_new_tokens = self.mtp_k + 1 | |
| # Skip block append for the first decode step after remote | |
| # prefill — blocks were already allocated during prefill. | |
| if not getattr(seq, "is_first_decode", False): | |
| self.block_manager.may_append(seq, num_new_tokens=num_new_tokens) |
| # --- Prefill scheduling --- | ||
| while self.waiting and num_seqs_prefill < self.max_num_seqs: | ||
| seq = self.waiting.popleft() | ||
|
|
The prefill loop no longer uses the delay_factor / _passed_delay() gating (but _passed_delay is still implemented). This is a behavior change that can regress batching efficiency and leaves dead code behind; either reintroduce the delay condition or remove the unused delay fields/method to keep scheduling logic consistent.
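For reference, reintroducing the gate could look roughly like this (the field names are assumptions drawn from the comment, not the scheduler's actual implementation):

```python
import time

class PrefillDelayGate:
    """Sketch of delay_factor gating: hold back prefill scheduling for a
    fraction of the last prompt's latency so requests can batch up."""

    def __init__(self, delay_factor: float = 0.0):
        self.delay_factor = delay_factor
        self._last_prefill_time = 0.0
        self._last_prompt_latency = 0.0

    def _passed_delay(self, now: float) -> bool:
        if self.delay_factor == 0.0:
            return True  # gating disabled: schedule immediately
        return (now - self._last_prefill_time
                > self.delay_factor * self._last_prompt_latency)
```

The prefill loop would then check something like `if self.waiting and self._passed_delay(time.time()):` before popping from the waiting queue.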
| remote_engine_id = f"{meta.remote_host}:{meta.remote_handshake_port}" | ||
| meta.remote_engine_id = remote_engine_id | ||
| dp0_id = self._engine_name_with_dp(remote_engine_id, 0) | ||
|
|
||
| if dp0_id not in self._remote_agents: | ||
| with self._handshake_lock: | ||
| if remote_engine_id not in self._remote_agents: | ||
| self._initiate_background_handshake( | ||
| req_id, remote_engine_id, meta | ||
| ) | ||
| need_handshake = True | ||
| continue | ||
|
|
Handshake cache lookup uses dp0_id = engine_dp0 but inside the lock checks if remote_engine_id not in self._remote_agents:. Since _remote_agents is populated with DP-suffixed keys (e.g. ..._dp0), this condition will stay true and can trigger redundant handshakes for every request. The inner check should likely use dp0_id (or otherwise normalize keys consistently).
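A sketch of the normalized double-checked lookup (the key format and helper names are assumptions mirroring the snippet above):

```python
import threading

_remote_agents: dict[str, object] = {}   # populated with DP-suffixed keys
_handshake_lock = threading.Lock()
handshakes_started: list[str] = []

def _engine_name_with_dp(engine_id: str, dp_rank: int) -> str:
    return f"{engine_id}_dp{dp_rank}"

def ensure_handshake(remote_engine_id: str) -> None:
    # Use the SAME normalized key for both the unlocked and locked checks,
    # so a completed handshake is not redone on every request.
    dp0_id = _engine_name_with_dp(remote_engine_id, 0)
    if dp0_id not in _remote_agents:
        with _handshake_lock:
            if dp0_id not in _remote_agents:
                handshakes_started.append(dp0_id)
                _remote_agents[dp0_id] = object()  # stand-in for the agent
```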
| EngineId = str | ||
| ReqId = str | ||
| TransferId = int | ||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Dataclasses | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| @dataclass | ||
| class KVConnectorOutput: | ||
| """Per-worker snapshot of finished KV cache transfers. | ||
|
|
||
| Each TP worker produces one of these per scheduler step. The | ||
| :class:`KVOutputAggregator` combines them to determine which | ||
| request IDs have finished on *all* workers. | ||
|
|
||
| Attributes: | ||
| finished_sending: Request IDs whose KV send completed on this worker. | ||
| finished_recving: Request IDs whose KV receive completed on this worker. | ||
| expected_finished_count: How many finished notifications should be | ||
| expected per request (used by the aggregator). | ||
| """ | ||
|
|
||
| finished_sending: set[str] = field(default_factory=set) | ||
| finished_recving: set[str] = field(default_factory=set) | ||
| expected_finished_count: int = 0 |
Type aliases and KVConnectorOutput currently use ReqId = str / set[str], but the scheduler/engine paths (e.g., Sequence.id, deferred_free_blocks, and POP_DONE_RECV parsing) operate on integer request IDs. This mismatch will confuse API consumers and can break equality checks if any path starts emitting strings. Consider standardizing on int request IDs here (or consistently converting to str at boundaries).
| def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: | ||
| # prefill | ||
| """Select the next batch of sequences for a forward pass. | ||
|
|
||
| Tries prefill first; if no new prefills are ready, falls back to | ||
| decoding already-running sequences. | ||
| """ | ||
| scheduled_seqs = {} | ||
| num_seqs_prefill = 0 | ||
| num_batched_tokens = 0 | ||
|
|
||
| skipped_waiting_requests: deque[Sequence] = deque() | ||
| num_scheduled_tokens: list[int] = [] | ||
| scheduled_spec_decode_tokens: dict[int, np.ndarray] = {} | ||
|
|
||
| if not self.running and not self.waiting: | ||
| # self.block_manager.reset() | ||
| return None | ||
|
|
schedule() is annotated as returning tuple[ScheduledBatch, dict[int, Sequence]] but it returns None when there are no requests. Please update the return type to Optional[...] (and any callers) to match the actual behavior and avoid type-checking/contract confusion.
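A sketch of the corrected contract, with simplified argument and return types:

```python
from typing import Optional

def schedule(waiting: list, running: list) -> Optional[tuple[list, dict]]:
    # Annotating Optional[...] makes the idle case part of the contract,
    # so callers (and type checkers) must handle None explicitly.
    if not waiting and not running:
        return None
    scheduled = list(waiting)
    return scheduled, {i: s for i, s in enumerate(scheduled)}
```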
| async with aiohttp.ClientSession( | ||
| timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000) |
The aiohttp ClientTimeout(total=6 * 6000 * 6000) is ~2500 days, which is likely a unit bug and can leave hung requests/sockets around indefinitely under failure. Please replace with a reasonable bound (e.g. seconds/minutes) and/or make it configurable.
| async with aiohttp.ClientSession( | |
| timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000) | |
| timeout_seconds = int(os.environ.get("DISAGG_PROXY_HTTP_TIMEOUT", "600")) | |
| async with aiohttp.ClientSession( | |
| timeout=aiohttp.ClientTimeout(total=timeout_seconds) |
| connector.start_load_kv(connector_meta_output) | ||
|
|
||
| @torch.inference_mode() | ||
| def async_proc_aggregation(self) -> KVConnectorOutput: |
When the connector is absent, async_proc_aggregation returns KVConnectorOutput(finished_sending=[], finished_recving=[]), but the dataclass fields are sets. Returning KVConnectorOutput() (or sets) avoids type inconsistency and prevents downstream code from accidentally treating these as sequences with duplicates.
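For instance, the no-connector path can simply rely on the dataclass defaults (the dataclass is reproduced from the snippet earlier in this review; the aggregation body is elided):

```python
from dataclasses import dataclass, field

@dataclass
class KVConnectorOutput:
    finished_sending: set[str] = field(default_factory=set)
    finished_recving: set[str] = field(default_factory=set)
    expected_finished_count: int = 0

def async_proc_aggregation(connector) -> KVConnectorOutput:
    if connector is None:
        # Defaults are empty sets, matching the declared field types.
        return KVConnectorOutput()
    raise NotImplementedError  # real aggregation path elided
```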
|
|
||
|
|
||
| def set_kv_cache_data( | ||
| kv_cache_data: dict[int, KVCacheTensor], config: Optional[Config] = None |
set_kv_cache_data is typed as dict[int, KVCacheTensor], but ForwardContext.kv_cache_data is dict[str, KVCacheTensor] and the caller builds keys like "layer_{i}". Align the type annotation (and any downstream assumptions) to dict[str, KVCacheTensor] to avoid confusion and static type errors.
| kv_cache_data: dict[int, KVCacheTensor], config: Optional[Config] = None | |
| kv_cache_data: dict[str, KVCacheTensor], config: Optional[Config] = None |
| import json | ||
|
|
||
| kv_config_str = kwargs.get("kv_transfer_config", "{}") | ||
| try: | ||
| config.kv_transfer_config = json.loads(kv_config_str) | ||
| logger.info(f"KV transfer config loaded: {config.kv_transfer_config}") | ||
| except json.JSONDecodeError: | ||
| config.kv_transfer_config = {} |
kv_transfer_config is already a Config field (and Config.__post_init__ parses string->dict). This extra json.loads(kwargs.get('kv_transfer_config', '{}')) is redundant and unsafe: if kv_transfer_config is already a dict it will raise TypeError (not caught), and it mutates config after CoreManager(config) has been constructed (so the engine core won’t see the updated value anyway). Consider removing this block and relying on Config to parse/hold KV config before CoreManager init.
| import json | |
| kv_config_str = kwargs.get("kv_transfer_config", "{}") | |
| try: | |
| config.kv_transfer_config = json.loads(kv_config_str) | |
| logger.info(f"KV transfer config loaded: {config.kv_transfer_config}") | |
| except json.JSONDecodeError: | |
| config.kv_transfer_config = {} |
Pull request overview
Copilot reviewed 31 out of 32 changed files in this pull request and generated 8 comments.
| logger.info( | ||
| f"LLMEngine init with {self.data_parallel_size} data parallel ranks" | ||
| ) |
There are two identical logger.info("LLMEngine init with ...") calls back-to-back. This duplicates log spam; remove one of them.
| logger.info( | |
| f"LLMEngine init with {self.data_parallel_size} data parallel ranks" | |
| ) |
| } | ||
| # vllm use register_kv_caches to register kv_cache_data. We just set it to global here | ||
| set_kv_cache_data(kv_cache_data) | ||
| set_kv_cache_data(kv_cache_data, config) |
set_kv_cache_data(kv_cache_data, config) uses config which is not defined in this scope. This will raise NameError during KV cache allocation. Pass self.config (or the intended config object) instead.
| set_kv_cache_data(kv_cache_data, config) | |
| set_kv_cache_data(kv_cache_data, self.config) |
| req_data["max_tokens"] -= 1 | ||
| req_data["kv_transfer_params"] = { |
req_data["max_tokens"] -= 1 will raise if the client sends null/omits max_tokens (it is optional in the OpenAI request model). Guard against None and validate that max_tokens >= 1 before decrementing, otherwise return a 400 with a clear message.
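A guarded version might look like this (the function name and error type are illustrative, not the server's actual handler):

```python
def adjust_prefill_max_tokens(req_data: dict) -> dict:
    # Prefill-side request: validate before decrementing so a missing or
    # null max_tokens yields a clear client error instead of a TypeError.
    max_tokens = req_data.get("max_tokens")
    if max_tokens is None or max_tokens < 1:
        raise ValueError("max_tokens must be an integer >= 1 for P/D prefill")
    req_data["max_tokens"] = max_tokens - 1
    return req_data
```

In the real handler, the ValueError would be mapped to an HTTP 400 response with the same message.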
| world_size: int = field(init=False) | ||
| """world_size is TPxPP, it affects the number of workers we create.""" | ||
| data_parallel_master_port: int = 29500 | ||
| data_parallel_master_port: int = field(default_factory=get_open_port) |
ParallelConfig.data_parallel_master_port now defaults to get_open_port(). Since each process typically constructs its own Config, different DP ranks can pick different master ports, breaking rendezvous for stateless_init_torch_distributed_process_group (all ranks must use the same host:port). Prefer a deterministic default (e.g. 29500) and/or require this value to be provided via env/args so it is consistent across ranks.
| data_parallel_master_port: int = field(default_factory=get_open_port) | |
| data_parallel_master_port: int = int(os.getenv("ATOM_DATA_PARALLEL_MASTER_PORT", "29500")) |
atom/model_engine/scheduler.py (Outdated)
| # Skip block append for the first decode step after remote | ||
| # prefill — blocks were already allocated during prefill. | ||
| if not getattr(seq, "is_first_decode", False): | ||
| self.block_manager.may_append(seq) | ||
|
|
||
| num_new_tokens = self.mtp_k + 1 |
In decode scheduling, BlockManager.may_append is called without num_new_tokens, so it defaults to 1. When self.mtp_k > 0 (or any case where more than 1 token is scheduled), this can under-allocate KV blocks and cause out-of-blocks or incorrect block tables. Pass the computed num_new_tokens into may_append (while still skipping the first-decode-after-remote-prefill case).
| # Skip block append for the first decode step after remote | |
| # prefill — blocks were already allocated during prefill. | |
| if not getattr(seq, "is_first_decode", False): | |
| self.block_manager.may_append(seq) | |
| num_new_tokens = self.mtp_k + 1 | |
| num_new_tokens = self.mtp_k + 1 | |
| # Skip block append for the first decode step after remote | |
| # prefill — blocks were already allocated during prefill. | |
| if not getattr(seq, "is_first_decode", False): | |
| self.block_manager.may_append(seq, num_new_tokens) |
| if not (hasattr(config, "kv_transfer_config") and config.kv_transfer_config): | ||
| return _global_kvconnector | ||
|
|
get_kvconnector returns _global_kvconnector when config is missing/disabled regardless of role. For role="scheduler", this can return the worker connector (or None) even if a scheduler connector was previously initialized. Consider returning _global_kvconnector_scheduler when role == "scheduler" (and/or branching the early return by role).
| _global_kvconnector_scheduler = KVConnectorFactory.create_connector( | ||
| config, role="scheduler" | ||
| ) | ||
| _logger.debug("Initialized global KVConnectorScheduler") |
The scheduler connector branch recreates a new connector on every call (no _global_kvconnector_scheduler is None guard). This can leak resources/threads and lose internal state. Cache the scheduler connector similarly to the worker connector and only create it once per process.
| _global_kvconnector_scheduler = KVConnectorFactory.create_connector( | |
| config, role="scheduler" | |
| ) | |
| _logger.debug("Initialized global KVConnectorScheduler") | |
| if _global_kvconnector_scheduler is None: | |
| _global_kvconnector_scheduler = KVConnectorFactory.create_connector( | |
| config, role="scheduler" | |
| ) | |
| _logger.debug("Initialized global KVConnectorScheduler") |
| if ( | ||
| self._ready_requests.empty() | ||
| and remote_engine_id not in self.load_ready_flag | ||
| ): |
The while need_handshake: loop busy-spins with continue when the handshake isn't ready yet, which can peg a CPU core. Add a small sleep/backoff (or use blocking Queue.get/Event) to avoid hot spinning while waiting for the handshake/ready queue.
| ): | |
| ): | |
| time.sleep(0.001) |
Key scheduling changes:
- `WAITING_FOR_REMOTE_KVS`: sequences that have been allocated blocks but are waiting for remote KV data to arrive.
- `ScheduledBatch` carries `connector_meta_output`, which is dispatched to workers to trigger async KV loading before the forward pass.
- Sequences transitioning from `WAITING_FOR_REMOTE_KVS` to `RUNNING` skip local block re-allocation (blocks were pre-allocated during the async load phase) and enter decode directly.

Model Runner & Worker Pipeline
- `AsyncIOProcManager` adds a separate per-worker ZMQ channel for KV transfer status, preventing it from mixing with regular forward outputs.
- `KVOutputAggregator` collects `finished_sending`/`finished_recving` signals from all TP ranks before reporting completion to the scheduler.
- `process_kvconnector_output` dispatches connector metadata to trigger RDMA loads; `async_proc_aggregation` collects transfer completion status. Both are called every engine step alongside the forward pass.

Each engine step now follows an extended pipeline:
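Sketched with stubbed objects (method names are taken from this summary; the engine/batch shapes are assumptions), one engine step looks roughly like:

```python
from types import SimpleNamespace

class StubEngine:
    """Records call order; stands in for the real engine/worker plumbing."""
    def __init__(self):
        self.calls = []
    def process_kvconnector_output(self, meta):
        self.calls.append("dispatch_kv")   # trigger async RDMA loads
    def run_forward(self, batch):
        self.calls.append("forward")
        return "outputs"
    def async_proc_aggregation(self):
        self.calls.append("collect_kv")    # gather finished_sending/recving
        return "kv_out"

def engine_step(engine, batch):
    # Extended per-step pipeline: dispatch KV metadata, run the forward
    # pass, then aggregate transfer-completion status.
    if batch is None:
        return None
    engine.process_kvconnector_output(batch.connector_meta_output)
    outputs = engine.run_forward(batch)
    engine.async_proc_aggregation()
    return outputs
```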
Front API Extensions
Proxy-based Request Orchestration
`kv_transfer_engine`: session lifecycle, handshake protocol, block-granularity KV reads, async completion tracking, MORI-IO backend integration
Other Related Modifications
Since the `seq_id` is defined based on processing progress, the prefill `seq_id` is used as the `transfer_id` to unify block management and release.

`AiterMLAMetadataBuilder` decode metadata preparation is updated to handle the first-decode-after-remote-prefill case.
Thanks to @ZhangLirong-amd for the help.
Test Plan
Accuracy:
- [x] GSM8K 5-shot accuracy benchmark (1P1D, DeepSeek-V3, FP8, TP=8, concurrent=64)
Performance:
- [x] 1P2D, using `de5012b13a33596865f3c3c3dd1e31fadaaaccd8` as the code base

Model: DeepSeek-V3 | 1P2D = 1 Prefill + 2 Decode nodes, all TP8 on MI300
TPOT ≤ 30 ms
TPOT ≤ 40 ms