30 commits
eda7a2f
Merge 0227 (#11)
inkcherry Mar 2, 2026
d5a747c
refine code
inkcherry Mar 2, 2026
ab68bfc
fix refine
inkcherry Mar 2, 2026
27da046
update comment
inkcherry Mar 2, 2026
a11325a
update deferred out config
inkcherry Mar 2, 2026
72fc4bb
update proxy
inkcherry Mar 2, 2026
30a2f2d
delete scripts
inkcherry Mar 2, 2026
4e10fe8
clean up
inkcherry Mar 2, 2026
a076f00
rename
inkcherry Mar 2, 2026
d1579bc
Merge off/main into atom_pd
inkcherry Mar 19, 2026
2e138b1
fix merge
inkcherry Mar 19, 2026
b625f27
fix non-pd dp path
inkcherry Mar 22, 2026
fe420e8
pd+dp gsm8k pass
inkcherry Mar 22, 2026
2f8536c
update readme
inkcherry Mar 22, 2026
fd4ec66
move the mesh folder
inkcherry Mar 22, 2026
185be91
refactor
inkcherry Mar 22, 2026
7e28839
fix refactor
inkcherry Mar 22, 2026
96a8443
update ut
inkcherry Mar 22, 2026
c8efdee
update ut
inkcherry Mar 22, 2026
450be20
fix ci
inkcherry Mar 22, 2026
8b1dc6c
update
inkcherry Mar 22, 2026
30efcad
format
inkcherry Mar 23, 2026
e035f54
format
inkcherry Mar 23, 2026
f46ea41
style: apply black formatting to pass CI checks
inkcherry Mar 23, 2026
3d46cb5
fix dsv3 mtp&qwen thinking ci
inkcherry Mar 23, 2026
43124b3
Merge remote-tracking branch 'origin/main' into pr253
inkcherry Mar 23, 2026
0a07b36
fix mtp
inkcherry Mar 24, 2026
19f1125
Merge main into atom_pd + rename mesh -> kv_transfer
inkcherry Apr 17, 2026
45f4761
Fix test_scheduler 3-tuple and revert kv_transfer_params on chat path
inkcherry Apr 17, 2026
88507de
style: black format kv_transfer imports after rename
inkcherry Apr 17, 2026
4 changes: 4 additions & 0 deletions README.md
@@ -11,6 +11,10 @@

**ATOM** (AiTer Optimized Model) is a lightweight vLLM-like implementation, focusing on integration and optimization based on [AITER](https://github.com/ROCm/aiter).

## 📢 News

- **[2026/03]** ATOM now supports **Prefill/Decode (P/D) disaggregation** — run prefill and decode on separate GPU nodes with RDMA-based KV cache transfer via [MORI-IO](https://github.com/ROCm/mori). See [disaggregation docs](atom/kv_transfer/disaggregation/README.md).

## 🚀 Features

- **ROCm Optimized**: Built on AMD's ROCm platform with [AITER](https://github.com/ROCm/aiter) kernels (ASM, CK, Triton)
15 changes: 14 additions & 1 deletion atom/config.py
@@ -603,7 +603,7 @@ class ParallelConfig:
set only in SPMD mode."""
world_size: int = field(init=False)
"""world_size is TPxPP, it affects the number of workers we create."""
data_parallel_master_port: int = 29500
data_parallel_master_port: int = field(default_factory=get_open_port)
Copilot AI Mar 23, 2026

ParallelConfig.data_parallel_master_port now defaults to get_open_port(). Since each process typically constructs its own Config, different DP ranks can pick different master ports, breaking rendezvous for stateless_init_torch_distributed_process_group (all ranks must use the same host:port). Prefer a deterministic default (e.g. 29500) and/or require this value to be provided via env/args so it is consistent across ranks.

Suggested change
data_parallel_master_port: int = field(default_factory=get_open_port)
data_parallel_master_port: int = int(os.getenv("ATOM_DATA_PARALLEL_MASTER_PORT", "29500"))

"""Port of the data parallel master."""

data_parallel_base_port: int = get_open_port()
Comment on lines 604 to 609
Copilot AI Mar 22, 2026

data_parallel_master_port now defaults to get_open_port(). If each DP rank constructs its own Config independently (common in multi-process launchers), they may pick different master ports and fail to rendezvous. The previous fixed default (29500) avoided this class of mismatch. If the goal is to avoid port collisions, consider deriving a single deterministic port (e.g., from an env var like ATOM_DP_MASTER_PORT, or selecting once on rank 0 and broadcasting) rather than calling get_open_port() during dataclass construction on every rank.

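Both review comments converge on the same remedy: resolve the DP master port deterministically (for example from an environment variable) so that every rank agrees on the rendezvous address. A minimal sketch of that suggestion follows; the variable name `ATOM_DP_MASTER_PORT` comes from the review comment and is not an existing ATOM setting.

```python
import os

def resolve_dp_master_port(default: int = 29500) -> int:
    """Pick the DP master port deterministically across ranks.

    Every rank reads the same environment variable (or falls back to the
    same fixed default), so all ranks rendezvous on one host:port instead
    of each grabbing an independent free port via get_open_port().
    """
    return int(os.getenv("ATOM_DP_MASTER_PORT", str(default)))
```

Because the value is derived from shared configuration rather than a per-process `get_open_port()` call, it stays consistent however many processes construct their own `Config`.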
@@ -780,6 +780,7 @@ class Config:
enable_dp_attention: bool = False
torch_dtype: torch.dtype = field(init=False)
speculative_config: Optional[SpeculativeConfig] = None
kv_transfer_config: dict = field(default_factory=dict)

enable_tbo: bool = False
enable_tbo_decode: bool = False
@@ -878,6 +879,18 @@ def __post_init__(self):
else torch.bfloat16
)

if hasattr(self, "kv_transfer_config") and isinstance(
self.kv_transfer_config, str
):
import json

try:
self.kv_transfer_config = json.loads(self.kv_transfer_config)
except json.JSONDecodeError:
import ast

self.kv_transfer_config = ast.literal_eval(self.kv_transfer_config)

if self.speculative_config is not None:
if self.speculative_config.num_speculative_tokens > 4:
raise ValueError(
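The `__post_init__` addition above coerces a string-valued `kv_transfer_config` into a dict: strict JSON is tried first, with `ast.literal_eval` as a fallback so Python-literal syntax (e.g. single-quoted keys from a shell) also parses. A standalone sketch of that coercion, simplified from the diff:

```python
import ast
import json

def coerce_kv_transfer_config(value):
    """Simplified sketch of the coercion in Config.__post_init__:
    non-strings pass through; strings are parsed as JSON, falling back
    to Python literal syntax for single-quoted dicts."""
    if not isinstance(value, str):
        return value
    try:
        return json.loads(value)
    except json.JSONDecodeError:
        # e.g. "{'kv_role': 'kv_consumer'}" is not valid JSON but is a
        # valid Python literal.
        return ast.literal_eval(value)
```

Note that `ast.literal_eval` raises `ValueError`/`SyntaxError` on malformed input, so a badly quoted `--kv-transfer-config` still fails loudly at startup.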
44 changes: 37 additions & 7 deletions atom/entrypoints/openai/api_server.py
@@ -137,11 +137,16 @@ def _send_stream_chunk_direct(
"finished_at": time.time(),
"started_at": started_at,
}
if getattr(request_output, "kv_transfer_params_output", None):
chunk_data["kv_transfer_params"] = request_output.kv_transfer_params_output
loop.call_soon_threadsafe(stream_queue.put_nowait, chunk_data)


async def generate_async(
prompt: str, sampling_params: SamplingParams, request_id: str
prompt: str,
sampling_params: SamplingParams,
request_id: str,
kv_transfer_params: Optional[Dict[str, Any]] = None,
) -> AsyncGenerator[Dict[str, Any], None]:
"""Generate text asynchronously for non-streaming requests."""
global engine, tokenizer
@@ -155,8 +160,13 @@ async def generate_async(
all_token_ids: List[int] = []
finish_reason: Optional[str] = None
seq = None
kv_transfer_output_meta_info = None

def completion_callback(request_output: RequestOutput):
nonlocal kv_transfer_output_meta_info
kv_transfer_output_meta_info = getattr(
request_output, "kv_transfer_params_output", None
)
now = time.time()
loop.call_soon_threadsafe(
token_queue.put_nowait,
@@ -170,7 +180,10 @@ def completion_callback(request_output: RequestOutput):

def do_preprocess():
return engine.io_processor.preprocess(
prompt, sampling_params, stream_callback=completion_callback
prompt,
sampling_params,
stream_callback=completion_callback,
kv_transfer_params=kv_transfer_params,
)

seq = await loop.run_in_executor(None, do_preprocess)
@@ -204,7 +217,7 @@ def do_preprocess():
else 0.0
)

yield {
response = {
"text": text,
"token_ids": all_token_ids,
"finish_reason": finish_reason,
@@ -214,6 +227,9 @@ def do_preprocess():
"tpot": tpot,
"latency": latency,
}
if kv_transfer_output_meta_info is not None:
response["kv_transfer_output_meta_info"] = kv_transfer_output_meta_info
yield response


def validate_model(requested_model: Optional[str]) -> None:
@@ -227,7 +243,10 @@ def validate_model(requested_model: Optional[str]) -> None:


async def setup_streaming_request(
prompt: str, sampling_params: SamplingParams, request_id: str
prompt: str,
sampling_params: SamplingParams,
request_id: str,
kv_transfer_params: Optional[Dict[str, Any]] = None,
) -> Tuple[int, asyncio.Queue]:
"""Set up a streaming request with the engine."""
global engine, _stream_queues, _seq_id_to_request_id
@@ -246,7 +265,10 @@ def stream_callback(request_output: RequestOutput) -> None:

def do_preprocess():
seq = engine.io_processor.preprocess(
prompt, sampling_params, stream_callback=stream_callback
prompt,
sampling_params,
stream_callback=stream_callback,
kv_transfer_params=kv_transfer_params,
)
_seq_id_to_request_id[seq.id] = request_id
return seq
@@ -427,7 +449,10 @@ async def completions(request: CompletionRequest):
# Streaming
if request.stream:
seq_id, stream_queue = await setup_streaming_request(
request.prompt, sampling_params, request_id
request.prompt,
sampling_params,
request_id,
kv_transfer_params=request.kv_transfer_params,
)
gen = stream_completion_response(
request_id,
@@ -445,7 +470,12 @@ async def completions(request: CompletionRequest):

# Non-streaming
final_output = None
async for output in generate_async(request.prompt, sampling_params, request_id):
async for output in generate_async(
request.prompt,
sampling_params,
request_id,
kv_transfer_params=request.kv_transfer_params,
):
final_output = output

if final_output is None:
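The non-streaming path above threads KV-transfer metadata out of the engine with a closure: `completion_callback` captures `kv_transfer_output_meta_info` via `nonlocal` and records whatever `kv_transfer_params_output` the final `RequestOutput` carries. A minimal sketch of that capture pattern, with stand-in objects in place of the real engine types:

```python
from types import SimpleNamespace

def run_request(outputs):
    """Sketch of the nonlocal-capture pattern in generate_async.

    `outputs` stands in for the stream of RequestOutput objects the
    engine feeds to the callback; only the last one typically carries
    kv_transfer_params_output.
    """
    kv_transfer_output_meta_info = None

    def completion_callback(request_output):
        nonlocal kv_transfer_output_meta_info
        # getattr with a default keeps non-P/D runs working unchanged.
        kv_transfer_output_meta_info = getattr(
            request_output, "kv_transfer_params_output", None
        )

    for out in outputs:  # stand-in for the engine driving the callback
        completion_callback(out)

    response = {"text": "..."}
    if kv_transfer_output_meta_info is not None:
        response["kv_transfer_output_meta_info"] = kv_transfer_output_meta_info
    return response
```

When no output carries the attribute, the response dict is left untouched, mirroring the `if kv_transfer_output_meta_info is not None` guard in the diff.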
4 changes: 4 additions & 0 deletions atom/entrypoints/openai/protocol.py
@@ -114,6 +114,8 @@ class CompletionRequest(BaseModel):
stop: Optional[List[str]] = None
ignore_eos: Optional[bool] = False
stream: Optional[bool] = False
# Optional KV-transfer metadata for P/D disaggregation.
kv_transfer_params: Optional[Dict[str, Any]] = None


# ============================================================================
@@ -143,6 +145,8 @@ class CompletionResponse(BaseModel):
model: str
choices: List[Dict[str, Any]]
usage: Dict[str, Any]
# Optional KV-transfer metadata returned for P/D disaggregation.
kv_transfer_params: Optional[Dict[str, Any]] = None


class ModelCard(BaseModel):
16 changes: 15 additions & 1 deletion atom/entrypoints/openai/serving_completion.py
@@ -24,6 +24,7 @@ def create_completion_chunk(
text: str,
finish_reason: Optional[str] = None,
usage: Optional[Dict] = None,
**extra_fields: Any,
) -> str:
"""Create a text completion chunk in SSE format."""
chunk = {
@@ -40,6 +41,7 @@
}
],
}
chunk.update(extra_fields)
if usage is not None:
chunk["usage"] = usage
return f"data: {json.dumps(chunk)}\n\n"
@@ -63,11 +65,16 @@ async def stream_completion_response(
new_text = chunk_data["text"]
num_tokens_output += len(chunk_data.get("token_ids", []))

extra_fields: Dict[str, Any] = {}
if "kv_transfer_params" in chunk_data:
extra_fields["kv_transfer_params"] = chunk_data["kv_transfer_params"]

yield create_completion_chunk(
request_id,
model,
new_text,
finish_reason=chunk_data.get("finish_reason"),
**extra_fields,
)

if chunk_data.get("finished", False):
@@ -99,7 +106,7 @@ def build_completion_response(
final_output: Dict[str, Any],
) -> CompletionResponse:
"""Build a non-streaming text completion response."""
return CompletionResponse(
response = CompletionResponse(
id=request_id,
created=int(time.time()),
model=model,
Expand All @@ -120,3 +127,10 @@ def build_completion_response(
"latency_s": round(final_output.get("latency", 0.0), 4),
},
)
if "kv_transfer_output_meta_info" in final_output:
response = response.model_copy(
update={
"kv_transfer_params": final_output["kv_transfer_output_meta_info"],
}
)
return response
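The `**extra_fields` hook added to `create_completion_chunk` lets the streaming path merge `kv_transfer_params` into the top of the SSE chunk without changing the call signature elsewhere. A simplified, self-contained sketch (the `object` field and choice layout follow the usual OpenAI completion shape, which this sketch assumes):

```python
import json

def create_completion_chunk(request_id, model, text,
                            finish_reason=None, **extra_fields):
    """Simplified sketch of the SSE chunk builder: arbitrary extra
    keyword fields (e.g. kv_transfer_params) are merged into the
    top-level chunk before serialization."""
    chunk = {
        "id": request_id,
        "object": "text_completion",
        "model": model,
        "choices": [
            {"index": 0, "text": text, "finish_reason": finish_reason}
        ],
    }
    chunk.update(extra_fields)  # e.g. kv_transfer_params from chunk_data
    return f"data: {json.dumps(chunk)}\n\n"
```

Clients consuming the stream can then pick `kv_transfer_params` off the final chunk the same way the non-streaming response exposes it.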
Empty file added atom/kv_transfer/__init__.py
155 changes: 155 additions & 0 deletions atom/kv_transfer/disaggregation/README.md
@@ -0,0 +1,155 @@
# KV Cache Disaggregation (Prefill/Decode Separation)

Prefill/Decode (P/D) disaggregation runs the prefill and decode phases on
separate GPU instances. The prefill node computes KV caches and transfers
them to the decode node via RDMA, so the decode node can skip prefill
entirely and start generating tokens immediately.

## MORI (Modular RDMA Interface)

The underlying KV cache transfer is powered by
[**MORI**](https://github.com/ROCm/mori) — a modular, high-performance
RDMA framework for GPU-centric communication on AMD platforms.

Specifically, this module uses **MORI-IO**, the point-to-point communication
library within MORI. MORI-IO provides:

- **GPU-direct RDMA** — data moves directly between GPU VRAM across nodes
without staging through host memory, minimizing latency and CPU overhead.
- **IBGDA (InfiniBand GPUDirect Async)** — RDMA operations are issued
directly from GPU kernels, bypassing the CPU entirely for the data path.
- **Session-based transfers** — MORI-IO pre-builds RDMA sessions (QP pairs,
memory registrations) during a one-time handshake. Subsequent transfers
reuse these sessions with near-zero setup cost.
- **Hardware support** — works with AMD MI300X/MI325X/MI355X GPUs and
ConnectX-7, Broadcom Thor2, and AMD Pollara (AINIC) NICs.

In the P/D disaggregation flow, the decode node uses MORI-IO to issue
RDMA READs against the prefill node's KV cache blocks. Each TP rank
independently reads its own KV slice, so the transfer is fully parallel
across the tensor-parallel group.

```
Client ──▶ Proxy (:10001)
Prefill Node (kv_producer) # 1. compute KV caches
Proxy # 2. receive block metadata
Decode Node (kv_consumer) # 3. RDMA read KV, generate tokens
Proxy ──▶ Client # 4. stream response back
```

## How to Run

### TP-only Mode (Tensor Parallelism)

#### 1. Start the Proxy

```bash
python -m atom.kv_transfer.disaggregation.proxy
# or with custom port:
python -m atom.kv_transfer.disaggregation.proxy --port 10001
```

#### 2. Start the Prefill Node

```bash
python -m atom.entrypoints.openai_server \
--kv_cache_dtype fp8 \
--model /path/to/model \
--block-size 16 \
-tp 8 \
  --kv-transfer-config '{"kv_role":"kv_producer","proxy_ip":"<PROXY_IP>","proxy_ping_port":36367,"http_port":8000}'
```
Comment on lines +62 to +69
Copilot AI Mar 22, 2026

The sample --kv-transfer-config JSON uses http_prt, but the implementation reads http_port (see KVConnector.__init__). As documented, the HTTP port setting will be ignored, which is confusing during deployment. Update the docs to use the correct key name.


Comment on lines +66 to +70
#### 3. Start the Decode Node

```bash
python -m atom.entrypoints.openai_server \
--kv_cache_dtype fp8 \
--model /path/to/model \
--block-size 16 \
-tp 8 \
  --kv-transfer-config '{"kv_role":"kv_consumer","proxy_ip":"<PROXY_IP>","proxy_ping_port":36367,"http_port":8000}'
```

#### 4. Send Requests (to the Proxy)

```bash
curl -s http://<PROXY_IP>:10001/v1/completions \
-H "Content-Type: application/json" \
-d '{"prompt":"1 2 3 4 5","max_tokens":10,"temperature":0}'
```

### DP + TP Mode (Data Parallelism + Tensor Parallelism)

When running MoE models (e.g. DeepSeek-V3/R1), you can enable data parallelism
with expert parallelism for higher throughput. Each DP rank runs a full TP
group, and MoE all-to-all is handled by MORI.

Key differences from TP-only mode:

- Add `--enable-dp-attention --enable-expert-parallel` to both prefill and
decode nodes.
- Set `MORI_SHMEM_MODE=ISOLATION` to separate MORI (MoE all-to-all) and
MORI-IO (KV transfer) symmetric heap memory pools — without this, the two
subsystems compete for the same memory and cause OOM during warmup.
- The prefill node reports its `dp_rank` back to the proxy so the decode node
knows which DP rank's KV cache to read.
- Each decode DP rank binds MORI-IO sessions to **all** prefill DP ranks
(not just its own), because any prefill DP rank may have processed the
request.

#### 1. Start the Proxy

Same as TP-only:

```bash
python -m atom.kv_transfer.disaggregation.proxy --port 10001
```

#### 2. Start the Prefill Node

```bash
export MORI_SHMEM_MODE=ISOLATION

python -m atom.entrypoints.openai_server \
--kv_cache_dtype fp8 \
--model /path/to/model \
--block-size 16 \
-tp 8 \
--enable-dp-attention \
--enable-expert-parallel \
  --kv-transfer-config '{"kv_role":"kv_producer","proxy_ip":"<PROXY_IP>","proxy_ping_port":36367,"http_port":8000}'
```
Comment on lines +122 to +130
Copilot AI Mar 22, 2026

This command example uses "http_prt" in --kv-transfer-config, but the code reads http_port. With the current spelling, the HTTP port override won’t take effect. Replace http_prt with http_port here.


#### 3. Start the Decode Node

```bash
export MORI_SHMEM_MODE=ISOLATION

python -m atom.entrypoints.openai_server \
--kv_cache_dtype fp8 \
--model /path/to/model \
--block-size 16 \
-tp 8 \
--enable-dp-attention \
--enable-expert-parallel \
  --kv-transfer-config '{"kv_role":"kv_consumer","proxy_ip":"<PROXY_IP>","proxy_ping_port":36367,"http_port":8000}'
```
Comment on lines +137 to +145
Copilot AI Mar 22, 2026

This command example uses "http_prt" in --kv-transfer-config, but the code reads http_port. With the current spelling, the HTTP port override won’t take effect. Replace http_prt with http_port here.


#### 4. Send Requests

Same as TP-only — requests go through the proxy:

```bash
curl -s http://<PROXY_IP>:10001/v1/completions \
-H "Content-Type: application/json" \
-d '{"prompt":"1 2 3 4 5","max_tokens":10,"temperature":0}'
```