# feat: PD disaggregation #253
```diff
@@ -603,7 +603,7 @@ class ParallelConfig:
     set only in SPMD mode."""
     world_size: int = field(init=False)
     """world_size is TPxPP, it affects the number of workers we create."""
-    data_parallel_master_port: int = 29500
+    data_parallel_master_port: int = field(default_factory=get_open_port)
     """Port of the data parallel master."""
+    data_parallel_base_port: int = get_open_port()
```
```diff
@@ -780,6 +780,7 @@ class Config:
     enable_dp_attention: bool = False
     torch_dtype: torch.dtype = field(init=False)
     speculative_config: Optional[SpeculativeConfig] = None
+    kv_transfer_config: dict = field(default_factory=dict)
     enable_tbo: bool = False
     enable_tbo_decode: bool = False
```
```diff
@@ -878,6 +879,18 @@ def __post_init__(self):
             else torch.bfloat16
         )

+        if hasattr(self, "kv_transfer_config") and isinstance(
+            self.kv_transfer_config, str
+        ):
+            import json
+
+            try:
+                self.kv_transfer_config = json.loads(self.kv_transfer_config)
+            except json.JSONDecodeError:
+                import ast
+
+                self.kv_transfer_config = ast.literal_eval(self.kv_transfer_config)
+
         if self.speculative_config is not None:
             if self.speculative_config.num_speculative_tokens > 4:
                 raise ValueError(
```
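The `--kv-transfer-config` value may arrive either as strict JSON (double quotes) or as a Python-literal string (single quotes). A minimal standalone sketch of the fallback parsing shown in the diff above (the helper name is illustrative, not the project's API):

```python
import ast
import json


def parse_kv_transfer_config(value):
    """Parse a kv-transfer-config CLI value into a dict.

    Tries strict JSON first; on failure, falls back to a Python
    literal (e.g. single-quoted dicts), mirroring the __post_init__
    logic in the diff above. Non-string values pass through unchanged.
    """
    if not isinstance(value, str):
        return value
    try:
        return json.loads(value)
    except json.JSONDecodeError:
        return ast.literal_eval(value)


# JSON input parses via json.loads
cfg = parse_kv_transfer_config('{"kv_role": "kv_producer", "proxy_ping_port": 36367}')
# Single-quoted (non-JSON) input falls back to ast.literal_eval
cfg2 = parse_kv_transfer_config("{'kv_role': 'kv_consumer'}")
```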
New file (+155 lines):
# KV Cache Disaggregation (Prefill/Decode Separation)

Prefill/Decode (P/D) disaggregation runs the prefill and decode phases on separate GPU instances. The prefill node computes KV caches and transfers them to the decode node via RDMA, so the decode node can skip prefill entirely and start generating tokens immediately.

## MORI (Modular RDMA Interface)

The underlying KV cache transfer is powered by [**MORI**](https://github.com/ROCm/mori), a modular, high-performance RDMA framework for GPU-centric communication on AMD platforms.

Specifically, this module uses **MORI-IO**, the point-to-point communication library within MORI. MORI-IO provides:

- **GPU-direct RDMA**: data moves directly between GPU VRAM across nodes without staging through host memory, minimizing latency and CPU overhead.
- **IBGDA (InfiniBand GPUDirect Async)**: RDMA operations are issued directly from GPU kernels, bypassing the CPU entirely for the data path.
- **Session-based transfers**: MORI-IO pre-builds RDMA sessions (QP pairs, memory registrations) during a one-time handshake. Subsequent transfers reuse these sessions with near-zero setup cost.
- **Hardware support**: works with AMD MI300X/MI325X/MI355X GPUs and ConnectX-7, Broadcom Thor2, and AMD Pollara (AINIC) NICs.

In the P/D disaggregation flow, the decode node uses MORI-IO to issue RDMA READs against the prefill node's KV cache blocks. Each TP rank independently reads its own KV slice, so the transfer is fully parallel across the tensor-parallel group.

```
Client ──▶ Proxy (:10001)
             │
             ▼
   Prefill Node (kv_producer)    # 1. compute KV caches
             │
             ▼
          Proxy                  # 2. receive block metadata
             │
             ▼
   Decode Node (kv_consumer)     # 3. RDMA read KV, generate tokens
             │
             ▼
   Proxy ──▶ Client              # 4. stream response back
```
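The RDMA-read step (3) is parallel across the TP group because each rank owns a distinct slice of the KV heads. A toy sketch of how a rank-local slice could be derived, assuming KV heads partition evenly across TP ranks (the function is illustrative, not MORI-IO's API):

```python
def kv_slice_for_rank(tp_rank: int, tp_size: int, num_kv_heads: int):
    """Return the (start, end) KV-head range owned by one TP rank.

    Each decode rank issues RDMA READs only for its own range, so
    transfers proceed independently across the whole TP group.
    """
    assert num_kv_heads % tp_size == 0, "assumes even head partitioning"
    heads_per_rank = num_kv_heads // tp_size
    start = tp_rank * heads_per_rank
    return start, start + heads_per_rank


# With tp=8 and 32 KV heads, rank 3 owns heads [12, 16)
print(kv_slice_for_rank(3, 8, 32))
```

The ranges are disjoint and cover every head exactly once, which is why no coordination between ranks is needed during the transfer.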
## How to Run

### TP-only Mode (Tensor Parallelism)

#### 1. Start the Proxy

```bash
python -m atom.kv_transfer.disaggregation.proxy
# or with a custom port:
python -m atom.kv_transfer.disaggregation.proxy --port 10001
```

#### 2. Start the Prefill Node

```bash
python -m atom.entrypoints.openai_server \
    --kv_cache_dtype fp8 \
    --model /path/to/model \
    --block-size 16 \
    -tp 8 \
    --kv-transfer-config '{"kv_role":"kv_producer","proxy_ip":"<PROXY_IP>","proxy_ping_port":36367,"http_prt":8000}'
```
#### 3. Start the Decode Node

```bash
python -m atom.entrypoints.openai_server \
    --kv_cache_dtype fp8 \
    --model /path/to/model \
    --block-size 16 \
    -tp 8 \
    --kv-transfer-config '{"kv_role":"kv_consumer","proxy_ip":"<PROXY_IP>","proxy_ping_port":36367,"http_prt":8000}'
```

#### 4. Send Requests (to the Proxy)

```bash
curl -s http://<PROXY_IP>:10001/v1/completions \
    -H "Content-Type: application/json" \
    -d '{"prompt":"1 2 3 4 5","max_tokens":10,"temperature":0}'
```
### DP + TP Mode (Data Parallelism + Tensor Parallelism)

When running MoE models (e.g. DeepSeek-V3/R1), you can enable data parallelism with expert parallelism for higher throughput. Each DP rank runs a full TP group, and MoE all-to-all is handled by MORI.

Key differences from TP-only mode:

- Add `--enable-dp-attention --enable-expert-parallel` to both prefill and decode nodes.
- Set `MORI_SHMEM_MODE=ISOLATION` to separate the MORI (MoE all-to-all) and MORI-IO (KV transfer) symmetric heap memory pools; without this, the two subsystems compete for the same memory and cause OOM during warmup.
- The prefill node reports its `dp_rank` back to the proxy so the decode node knows which DP rank's KV cache to read.
- Each decode DP rank binds MORI-IO sessions to **all** prefill DP ranks (not just its own), because any prefill DP rank may have processed the request.
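The all-to-all session binding can be pictured as a small map from each decode DP rank to one session per prefill DP rank. A purely illustrative sketch (the string placeholders stand in for real MORI-IO session objects, which are not modeled here):

```python
def build_session_map(decode_dp_ranks, prefill_dp_ranks):
    """Map each decode DP rank to a session per prefill DP rank.

    Every decode rank binds to *all* prefill ranks up front, since
    any prefill rank may hold the KV cache for a given request.
    """
    return {
        d: {p: f"session:{d}->{p}" for p in prefill_dp_ranks}
        for d in decode_dp_ranks
    }


# With 2 decode and 2 prefill DP ranks, each decode rank holds 2 sessions
sessions = build_session_map(range(2), range(2))
```

At request time the decode rank simply looks up the session for whichever prefill `dp_rank` the proxy reported, with no extra handshake on the hot path.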
#### 1. Start the Proxy

Same as TP-only:

```bash
python -m atom.kv_transfer.disaggregation.proxy --port 10001
```

#### 2. Start the Prefill Node

```bash
export MORI_SHMEM_MODE=ISOLATION

python -m atom.entrypoints.openai_server \
    --kv_cache_dtype fp8 \
    --model /path/to/model \
    --block-size 16 \
    -tp 8 \
    --enable-dp-attention \
    --enable-expert-parallel \
    --kv-transfer-config '{"kv_role":"kv_producer","proxy_ip":"<PROXY_IP>","proxy_ping_port":36367,"http_prt":8000}'
```
#### 3. Start the Decode Node

```bash
export MORI_SHMEM_MODE=ISOLATION

python -m atom.entrypoints.openai_server \
    --kv_cache_dtype fp8 \
    --model /path/to/model \
    --block-size 16 \
    -tp 8 \
    --enable-dp-attention \
    --enable-expert-parallel \
    --kv-transfer-config '{"kv_role":"kv_consumer","proxy_ip":"<PROXY_IP>","proxy_ping_port":36367,"http_prt":8000}'
```
Comment on lines
+137
to
+145
|
||
|
|
||
#### 4. Send Requests

Same as TP-only; requests go through the proxy:

```bash
curl -s http://<PROXY_IP>:10001/v1/completions \
    -H "Content-Type: application/json" \
    -d '{"prompt":"1 2 3 4 5","max_tokens":10,"temperature":0}'
```
> **Review comment:** `ParallelConfig.data_parallel_master_port` now defaults to `get_open_port()`. Since each process typically constructs its own `Config`, different DP ranks can pick different master ports, breaking rendezvous for `stateless_init_torch_distributed_process_group` (all ranks must use the same host:port). Prefer a deterministic default (e.g. 29500) and/or require this value to be provided via env/args so it is consistent across ranks.
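One way to realize the reviewer's suggestion (a sketch, not the project's actual fix): resolve the master port from an environment variable with a fixed fallback, so every rank computes the same value. The `DP_MASTER_PORT` variable name is hypothetical.

```python
import os


def resolve_master_port(default: int = 29500) -> int:
    """Resolve the DP master port identically on every rank.

    Unlike get_open_port(), this cannot differ across processes:
    either all ranks read the same env var, or all ranks fall back
    to the same fixed default, so rendezvous always agrees.
    """
    return int(os.environ.get("DP_MASTER_PORT", default))


# Every DP rank evaluating this under the same environment gets one port
port = resolve_master_port()
```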